import time
import warnings
from collections import deque
from enum import IntEnum
import numpy as np
from numpy import array
from recordtype import recordtype
from flatland.utils.graphics_pgl import PGLGL
from flatland.utils.graphics_pil import PILGL, PILSVG
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
[docs]
class AgentRenderVariant(IntEnum):
    BOX_ONLY = 0
    ONE_STEP_BEHIND = 1
    AGENT_SHOWS_OPTIONS = 2
    ONE_STEP_BEHIND_AND_BOX = 3
    AGENT_SHOWS_OPTIONS_AND_BOX = 4 
[docs]
class RenderBase(object):
    def __init__(self, env):
        pass
[docs]
    def render_env(self):
        pass 
[docs]
    def close_window(self):
        pass 
[docs]
    def set_new_rail(self):
        """ Signal to the renderer that the env has changed and will need re-rendering.
        """
        pass 
[docs]
    def update_background(self):
        """ A lesser version of set_new_rail?
            TODO: can update_background be pruned for simplicity?
        """
        pass 
 
[docs]
class RenderLocal(RenderBase):
    """ Class to render the RailEnv and agents.
        Uses two layers, layer 0 for rails (mostly static), layer 1 for agents etc (dynamic)
        The lower / rail layer 0 is only redrawn after set_new_rail() has been called.
        Created with a "GraphicsLayer" or gl - now either PIL or PILSVG
    """
    visit = recordtype("visit", ["rc", "iDir", "iDepth", "prev"])
    color_list = list("brgcmyk")
    # \delta RC for NESW
    transitions_row_col = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
    pix_per_cell = 1  # misnomer...
    half_pix_per_cell = pix_per_cell / 2
    x_y_half = array([half_pix_per_cell, -half_pix_per_cell])
    row_col_to_xy = array([[0, -pix_per_cell], [pix_per_cell, 0]])
    grid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[pix_per_cell]], [[pix_per_cell]]])
    theta = np.linspace(0, np.pi / 2, 5)
    arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]
    def __init__(self, env, gl="PILSVG", jupyter=False,
                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                 show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600):
        self.env = env
        self.frame_nr = 0
        self.start_time = time.time()
        self.times_list = deque()
        self.agent_render_variant = agent_render_variant
        self.gl_str = gl
        if gl == "PIL":
            self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
        elif gl == "PILSVG":
            self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
        else:
            if gl != "PGL":
                print("[", gl, "] not found, switch to PGL, PILSVG")
                print("Using PGL")
            self.gl = PGLGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
        self.new_rail = True
        self.show_debug = show_debug
        self.clear_debug_text = clear_debug_text
        self.update_background()
[docs]
    def reset(self):
        """
        Resets the environment
        :return:
        """
        self.set_new_rail()
        self.frame_nr = 0
        self.start_time = time.time()
        self.times_list = deque()
        return 
[docs]
    def update_background(self):
        # create background map
        targets = {}
        for agent_idx, agent in enumerate(self.env.agents):
            if agent is None:
                continue
            #print(f"updatebg: {agent_idx} {agent.target}")
            targets[tuple(agent.target)] = agent_idx
        self.gl.build_background_map(targets) 
[docs]
    def resize(self):
        self.gl.resize(self.env) 
[docs]
    def set_new_rail(self):
        """ Tell the renderer that the rail has changed.
            eg when the rail has been regenerated, or updated in the editor.
        """
        self.new_rail = True 
[docs]
    def plot_agents(self, targets=True, selected_agent=None):
        color_map = self.gl.get_cmap('hsv', lut=(len(self.env.agents) + 1))
        for agent_idx, agent in enumerate(self.env.agents):
            if agent is None:
                continue
            color = color_map(agent_idx)
            self.plot_single_agent(agent.position, agent.direction, color, target=agent.target if targets else None,
                                   static=True, selected=agent_idx == selected_agent)
        for agent_idx, agent in enumerate(self.env.agents):
            if agent is None:
                continue
            color = color_map(agent_idx)
            self.plot_single_agent(agent.position, agent.direction, color, target=agent.target if targets else None) 
[docs]
    def get_transition_row_col(self, row_col_pos, direction, bgiTrans=False):
        """
        Get the available transitions for row_col_pos in direction direction,
        as row & col deltas.
        If bgiTrans is True, return a grid of indices of available transitions.
        eg for a cell row_col_pos = (4,5), in direction direction = 0 (N),
        where the available transitions are N and E, returns:
        [[-1,0], [0,1]] ie N=up one row, and E=right one col.
        and if bgiTrans is True, returns a tuple:
        (
            [[-1,0], [0,1]], # deltas as before
            [0, 1] #  available transition indices, ie N, E
        )
        """
        transitions = self.env.rail.get_transitions((row_col_pos, direction))
        transition_list = np.where(transitions)[0]  # RC list of transitions
        # HACK: workaround dead-end transitions
        if len(transition_list) == 0:
            reverse_direciton = (direction + 2) % 4
            transitions = tuple(int(tmp_dir == reverse_direciton) for tmp_dir in range(4))
            transition_list = np.where(transitions)[0]  # RC list of transitions
        transition_grid = self.__class__.transitions_row_col[transition_list]
        if bgiTrans:
            return transition_grid, transition_list
        else:
            return transition_grid 
[docs]
    def plot_single_agent(self, position_row_col, direction, color="r", target=None, static=False, selected=False):
        """
        Plot a simple agent.
        Assumes a working graphics layer context (cf a MPL figure).
        """
        if position_row_col is None:
            return
        rt = self.__class__
        direction_row_col = rt.transitions_row_col[direction]  # agent direction in RC
        direction_xy = np.matmul(direction_row_col, rt.row_col_to_xy)  # agent direction in xy
        xyPos = np.matmul(position_row_col - direction_row_col / 2, rt.row_col_to_xy) + rt.x_y_half
        if static:
            color = self.gl.adapt_color(color, lighten=True)
        color = color
        self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100)  # agent location
        xy_dir_line = array([xyPos, xyPos + direction_xy / 2]).T  # line for agent orient.
        self.gl.plot(*xy_dir_line, color=color, layer=1, lw=5, ms=0, alpha=0.6)
        if selected:
            self._draw_square(xyPos, 1, color)
        if target is not None:
            target_row_col = array(target)
            target_xy = np.matmul(target_row_col, rt.row_col_to_xy) + rt.x_y_half
            self._draw_square(target_xy, 1 / 3, color, layer=1) 
[docs]
    def plot_transition(self, position_row_col, transition_row_col, color="r", depth=None):
        """
        plot the transitions in transition_row_col at position position_row_col.
        transition_row_col is a 2d numpy array containing a list of RC transitions,
        eg [[-1,0], [0,1]] means N, E.
        """
        rt = self.__class__
        position_xy = np.matmul(position_row_col, rt.row_col_to_xy) + rt.x_y_half
        transition_xy = position_xy + np.matmul(transition_row_col, rt.row_col_to_xy / 2.4)
        self.gl.scatter(*transition_xy.T, color=color, marker="o", s=50, alpha=0.2)
        if depth is not None:
            for x, y in transition_xy:
                self.gl.text(x, y, depth) 
[docs]
    def draw_transition(self,
                        line,
                        center,
                        rotation,
                        dead_end=False,
                        curves=False,
                        color="gray",
                        arrow=True,
                        spacing=0.1):
        """
        gLine is a numpy 2d array of points,
        in the plotting space / coords.
        eg:
        [[0,.5],[1,0.2]] means a line
        from x=0, y=0.5
        to   x=1, y=0.2
        """
        if not curves and not dead_end:
            # just a straigt line, no curve nor dead_end included in this basic rail element
            self.gl.plot(
                [line[0][0], line[1][0]],  # x
                [line[0][1], line[1][1]],  # y
                color=color
            )
        else:
            # it was not a simple line to draw: the rail has a curve or dead_end included.
            rt = self.__class__
            straight = rotation in [0, 2]
            dx, dy = np.squeeze(np.diff(line, axis=0)) * spacing / 2
            if straight:
                if color == "auto":
                    if dx > 0 or dy > 0:
                        color = "C1"  # N or E
                    else:
                        color = "C2"  # S or W
                if dead_end:
                    line_xy = array([
                        line[1] + [dy, dx],
                        center,
                        line[1] - [dy, dx],
                    ])
                    self.gl.plot(*line_xy.T, color=color)
                else:
                    line_xy = line + [-dy, dx]
                    self.gl.plot(*line_xy.T, color=color)
                    if arrow:
                        middle_xy = np.sum(line_xy * [[1 / 4], [3 / 4]], axis=0)
                        arrow_xy = array([
                            middle_xy + [-dx - dy, +dx - dy],
                            middle_xy,
                            middle_xy + [-dx + dy, -dx - dy]])
                        self.gl.plot(*arrow_xy.T, color=color)
            else:
                middle_xy = np.mean(line, axis=0)
                dxy = middle_xy - center
                corner = middle_xy + dxy
                if rotation == 1:
                    arc_factor = 1 - spacing
                    color_auto = "C1"
                else:
                    arc_factor = 1 + spacing
                    color_auto = "C2"
                dxy2 = (center - corner) * arc_factor  # for scaling the arc
                if color == "auto":
                    color = color_auto
                self.gl.plot(*(rt.arc * dxy2 + corner).T, color=color)
                if arrow:
                    dx, dy = np.squeeze(np.diff(line, axis=0)) / 20
                    iArc = int(len(rt.arc) / 2)
                    middle_xy = corner + rt.arc[iArc] * dxy2
                    arrow_xy = array([
                        middle_xy + [-dx - dy, +dx - dy],
                        middle_xy,
                        middle_xy + [-dx + dy, -dx - dy]])
                    self.gl.plot(*arrow_xy.T, color=color) 
[docs]
    def render_observation(self, agent_handles, observation_dict):
        """
        Render the extent of the observation of each agent. All cells that appear in the agent
        observation will be highlighted.
        :param agent_handles: List of agent indices to adapt color and get correct observation
        :param observation_dict: dictionary containing sets of cells of the agent observation
        """
        rt = self.__class__
        # Check if the observation builder provides an observation
        if len(observation_dict) < 1:
            warnings.warn(
                "Predictor did not provide any predicted cells to render. \
                Observation builder needs to populate: env.dev_obs_dict")
        else:
            for agent in agent_handles:
                color = self.gl.get_agent_color(agent)
                for visited_cell in observation_dict[agent]:
                    cell_coord = array(visited_cell[:2])
                    cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                    self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) 
[docs]
    def render_prediction(self, agent_handles, prediction_dict):
        """
        Render the extent of the observation of each agent. All cells that appear in the agent
        observation will be highlighted.
        :param agent_handles: List of agent indices to adapt color and get correct observation
        :param observation_dict: dictionary containing sets of cells of the agent observation
        """
        rt = self.__class__
        if len(prediction_dict) < 1:
            warnings.warn(
                "Predictor did not provide any predicted cells to render. \
                Predictors builder needs to populate: env.dev_pred_dict")
        else:
            for agent in agent_handles:
                color = self.gl.get_agent_color(agent)
                for visited_cell in prediction_dict[agent]:
                    cell_coord = array(visited_cell[:2])
                    if type(self.gl) is PILSVG:
                        # TODO : Track highlighting (Adrian)
                        r = cell_coord[0]
                        c = cell_coord[1]
                        transitions = self.env.rail.grid[r, c]
                        self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
                    else:
                        cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                        self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) 
[docs]
    def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):
        cell_size = 1  # TODO: remove cell_size
        env = self.env
        # Draw cells grid
        grid_color = [0.95, 0.95, 0.95]
        for row in range(env.height + 1):
            self.gl.plot([0, (env.width + 1) * cell_size],
                         [-row * cell_size, -row * cell_size],
                         color=grid_color, linewidth=2)
        for col in range(env.width + 1):
            self.gl.plot([col * cell_size, col * cell_size],
                         [0, -(env.height + 1) * cell_size],
                         color=grid_color, linewidth=2)
        # Draw each cell independently
        for row in range(env.height):
            for col in range(env.width):
                # bounding box of the grid cell
                x0 = cell_size * col  # left
                x1 = cell_size * (col + 1)  # right
                y0 = cell_size * -row  # top
                y1 = cell_size * -(row + 1)  # bottom
                # centres of cell edges
                coords = [
                    ((x0 + x1) / 2.0, y0),  # N middle top
                    (x1, (y0 + y1) / 2.0),  # E middle right
                    ((x0 + x1) / 2.0, y1),  # S middle bottom
                    (x0, (y0 + y1) / 2.0)  # W middle left
                ]
                # cell centre
                center_xy = array([x0, y1]) + cell_size / 2
                # cell transition values
                cell = env.rail.get_full_transitions(row, col)
                cell_valid = env.rail.cell_neighbours_valid((row, col), check_this_cell=True)
                # Special Case 7, with a single bit; terminate at center
                nbits = 0
                tmp = cell
                while tmp > 0:
                    nbits += (tmp & 1)
                    tmp = tmp >> 1
                # as above - move the from coord to the centre
                # it's a dead env.
                is_dead_end = nbits == 1
                if not cell_valid:
                    self.gl.scatter(*center_xy, color="r", s=30)
                for orientation in range(4):  # ori is where we're heading
                    from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
                    from_xy = coords[from_ori]
                    moves = env.rail.get_transitions(((row, col), orientation))
                    for to_ori in range(4):
                        to_xy = coords[to_ori]
                        rotation = (to_ori - from_ori) % 4
                        if (moves[to_ori]):  # if we have this transition
                            self.draw_transition(
                                array([from_xy, to_xy]), center_xy,
                                rotation, dead_end=is_dead_end, curves=curves and not is_dead_end, spacing=spacing,
                                color=rail_color) 
[docs]
    def render_env(self,
                   show=False,  # whether to call matplotlib show() or equivalent after completion
                   show_agents=True,  # whether to include agents
                   show_inactive_agents=False,
                   show_observations=True,  # whether to include observations
                   show_predictions=False,  # whether to include predictions
                   show_rowcols=False,  # label the rows and columns
                   frames=False,  # frame counter to show (intended since invocation)
                   episode=None,  # int episode number to show
                   step=None,  # int step number to show in image
                   selected_agent=None,  # indicate which agent is "selected" in the editor
                   return_image=False): # indicate if image is returned for use in monitor:
        """ Draw the environment using the GraphicsLayer this RenderTool was created with.
            (Use show=False from a Jupyter notebook with %matplotlib inline)
        """
        # if type(self.gl) is PILSVG:
        if self.gl_str in ["PILSVG", "PGL"]:
            return self.render_env_svg(show=show,
                                show_observations=show_observations,
                                show_predictions=show_predictions,
                                selected_agent=selected_agent,
                                show_agents=show_agents,
                                show_inactive_agents=show_inactive_agents,
                                show_rowcols=show_rowcols,
                                return_image=return_image
                                )
        else:
            return self.render_env_pil(show=show,
                                show_agents=show_agents,
                                show_inactive_agents=show_inactive_agents,
                                show_observations=show_observations,
                                show_predictions=show_predictions,
                                show_rowcols=show_rowcols,
                                frames=frames,
                                episode=episode,
                                step=step,
                                selected_agent=selected_agent,
                                return_image=return_image
                                ) 
    def _draw_square(self, center, size, color, opacity=255, layer=0):
        x0 = center[0] - size / 2
        x1 = center[0] + size / 2
        y0 = center[1] - size / 2
        y1 = center[1] + size / 2
        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
[docs]
    def get_image(self):
        return self.gl.get_image() 
[docs]
    def render_env_pil(self,
                       show=False,  # whether to call matplotlib show() or equivalent after completion
                       # use false when calling from Jupyter.  (and matplotlib no longer supported!)
                       show_agents=True,  # whether to include agents
                       show_inactive_agents=False,
                       show_observations=True,  # whether to include observations
                       show_predictions=False,  # whether to include predictions
                       show_rowcols=False, # label the rows and columns
                       frames=False,  # frame counter to show (intended since invocation)
                       episode=None,  # int episode number to show
                       step=None,  # int step number to show in image
                       selected_agent=None,  # indicate which agent is "selected" in the editor
                       return_image=False # indicate if image is returned for use in monitor:
                       ):
        if type(self.gl) is PILGL:
            self.gl.begin_frame()
        env = self.env
        self.render_rail()
        # Draw each agent + its orientation + its target
        if show_agents:
            self.plot_agents(targets=True, selected_agent=selected_agent)
        if show_observations:
            self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
        if show_predictions and len(env.dev_pred_dict) > 0:
            self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
        # Draw some textual information like fps
        text_y = [-0.3, -0.6, -0.9]
        if frames:
            self.gl.text(0.1, text_y[2], "Frame:{:}".format(self.frame_nr))
        self.frame_nr += 1
        if episode is not None:
            self.gl.text(0.1, text_y[1], "Ep:{}".format(episode))
        if step is not None:
            self.gl.text(0.1, text_y[0], "Step:{}".format(step))
        time_now = time.time()
        self.gl.text(2, text_y[2], "elapsed:{:.2f}s".format(time_now - self.start_time))
        self.times_list.append(time_now)
        if len(self.times_list) > 20:
            self.times_list.popleft()
        if len(self.times_list) > 1:
            rFps = (len(self.times_list) - 1) / (self.times_list[-1] - self.times_list[0])
            self.gl.text(2, text_y[1], "fps:{:.2f}".format(rFps))
        self.gl.prettify2(env.width, env.height, self.pix_per_cell)
        # TODO: for MPL, we don't want to call clf (called by endframe)
        # if not show:
        if show and type(self.gl) is PILGL:
            self.gl.show()
        self.gl.pause(0.00001)
        if return_image:
            return self.get_image()
        return 
[docs]
    def render_env_svg(
        self, show=False, show_observations=True, show_predictions=False, selected_agent=None,
        show_agents=True, show_inactive_agents=False, show_rowcols=False, return_image=False
    ):
        """
        Renders the environment with SVG support (nice image)
        """
        env = self.env
        self.gl.begin_frame()
        if self.new_rail:
            self.new_rail = False
            self.gl.clear_rails()
            # store the targets
            targets = {}
            selected = {}
            for agent_idx, agent in enumerate(self.env.agents):
                if agent is None:
                    continue
                targets[tuple(agent.target)] = agent_idx
                selected[tuple(agent.target)] = (agent_idx == selected_agent)
            # Draw each cell independently
            for r in range(env.height):
                for c in range(env.width):
                    transitions = env.rail.grid[r, c]
                    if (r, c) in targets:
                        target = targets[(r, c)]
                        is_selected = selected[(r, c)]
                    else:
                        target = None
                        is_selected = False
                    self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected,
                                        rail_grid=env.rail.grid, num_agents=env.get_num_agents(),
                                        show_debug=self.show_debug)
            self.gl.build_background_map(targets)
            if show_rowcols:
                # label rows, cols
                for iRow in range(env.height):
                    self.gl.text_rowcol((iRow, 0), str(iRow), layer=self.gl.RAIL_LAYER)
                for iCol in range(env.width):
                    self.gl.text_rowcol((0, iCol), str(iCol), layer=self.gl.RAIL_LAYER)
        if show_agents:
            for agent_idx, agent in enumerate(self.env.agents):
                if agent is None:
                    continue
                # Show an agent even if it hasn't already started
                if agent.position is None:
                    if show_inactive_agents:
                        # print("agent ", agent_idx, agent.position, agent.old_position, agent.initial_position)
                        self.gl.set_agent_at(agent_idx, *(agent.initial_position),
                            agent.initial_direction, agent.initial_direction,
                            is_selected=(selected_agent == agent_idx),
                            rail_grid=env.rail.grid,
                            show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
                            malfunction=False)
                    continue
                is_malfunction = agent.malfunction_handler.malfunction_down_counter > 0
                if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
                    self.gl.set_cell_occupied(agent_idx, *(agent.position))
                elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \
                    
self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:  # noqa: E125
                    # Most common case - the agent has been running for >1 steps
                    if agent.old_position is not None:
                        position = agent.old_position
                        direction = agent.direction
                        old_direction = agent.old_direction
                    # the agent's first step - it doesn't have an old position yet
                    elif agent.position is not None:
                        position = agent.position
                        direction = agent.direction
                        old_direction = agent.direction
                    # When the editor has just added an agent
                    elif agent.initial_position is not None:
                        position = agent.initial_position
                        direction = agent.initial_direction
                        old_direction = agent.initial_direction
                    # set_agent_at uses the agent index for the color
                    if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
                        self.gl.set_cell_occupied(agent_idx, *(agent.position))
                    self.gl.set_agent_at(agent_idx, *position, old_direction, direction,
                                         selected_agent == agent_idx, rail_grid=env.rail.grid,
                                         show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
                                         malfunction=is_malfunction)
                else:
                    position = agent.position
                    direction = agent.direction
                    for possible_direction in range(4):
                        # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible?
                        isValid = env.rail.get_transition((agent.position, agent.direction), possible_direction)
                        if isValid:
                            direction = possible_direction
                            # set_agent_at uses the agent index for the color
                            self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
                                                 selected_agent == agent_idx, rail_grid=env.rail.grid,
                                                 show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
                                                 malfunction=is_malfunction)
                    # set_agent_at uses the agent index for the color
                    if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
                        self.gl.set_cell_occupied(agent_idx, *(agent.position))
                    if show_inactive_agents:
                        show_this_agent = True
                    else:
                        show_this_agent = agent.state.is_on_map_state()
                    if show_this_agent:
                        self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
                                        selected_agent == agent_idx,
                                        rail_grid=env.rail.grid, malfunction=is_malfunction)
        if show_observations:
            self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
        if show_predictions:
            self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
        if show:
            self.gl.show()
        for i in range(3):
            self.gl.process_events()
        self.frame_nr += 1
        if return_image:
            return self.get_image()
        return 
[docs]
    def close_window(self):
        self.gl.close_window()