"""This module contains utility functions making it easier to debug POMDP
planning.
TreeDebugger
************
The core debugging functionality for POMCP/POUCT search trees is incorporated
into the TreeDebugger.  It is designed for ease of use during a :code:`pdb` or
:code:`ipdb` debugging session. Here is a minimal example usage:
.. code-block:: python
   import pomdp_py
   from pomdp_py.utils import TreeDebugger
   from pomdp_problems.tiger import TigerProblem
   # pomdp_py.Agent
   agent = TigerProblem.create("tiger-left", 0.5, 0.15).agent
   # suppose pouct is a pomdp_py.POUCT object (POMCP works too)
   pouct = pomdp_py.POUCT(max_depth=4, discount_factor=0.95,
                          num_sims=4096, exploration_const=200,
                          rollout_policy=agent.policy_model)
   action = pouct.plan(agent)
   dd = TreeDebugger(agent.tree)
   import pdb; pdb.set_trace()
When the program executes, you enter the pdb debugger, and you can:
.. code-block:: text
    (Pdb) dd.pp
    _VNodePP(n=4095, v=-19.529)(depth=0)
    ├─── ₀listen⟶_QNodePP(n=4059, v=-19.529)
    │    ├─── ₀tiger-left⟶_VNodePP(n=2013, v=-16.586)(depth=1)
    │    │    ├─── ₀listen⟶_QNodePP(n=1883, v=-16.586)
    │    │    │    ├─── ₀tiger-left⟶_VNodePP(n=1441, v=-8.300)(depth=2)
    ... # prints out the entire tree; Colored in terminal.
    (Pdb) dd.p(1)
    _VNodePP(n=4095, v=-19.529)(depth=0)
    ├─── ₀listen⟶_QNodePP(n=4059, v=-19.529)
    │    ├─── ₀tiger-left⟶_VNodePP(n=2013, v=-16.586)(depth=1)
    │    │    ├─── ₀listen⟶_QNodePP(n=1883, v=-16.586)
    │    │    ├─── ₁open-left⟶_QNodePP(n=18, v=-139.847)
    │    │    └─── ₂open-right⟶_QNodePP(n=112, v=-57.191)
    ... # prints up to depth 1
Note that the printed texts are colored in the terminal.
You can retrieve the subtree through indexing:
.. code-block:: text
    (Pdb) dd[0]
    listen⟶_QNodePP(n=4059, v=-19.529)
        - [0] tiger-left: VNode(n=2013, v=-16.586)
        - [1] tiger-right: VNode(n=2044, v=-16.160)
    (Pdb) dd[0][1][2]
    open-right⟶_QNodePP(n=15, v=-148.634)
        - [0] tiger-left: VNode(n=7, v=-20.237)
        - [1] tiger-right: VNode(n=6, v=8.500)
You can obtain the currently preferred action sequence by:
.. code-block:: text
    (Pdb) dd.mbp
       listen  []
       listen  []
       listen  []
       listen  []
       open-left  []
     _VNodePP(n=4095, v=-19.529)(depth=0)
     ├─── ₀listen⟶_QNodePP(n=4059, v=-19.529)
     │    └─── ₁tiger-right⟶_VNodePP(n=2044, v=-16.160)(depth=1)
     │         ├─── ₀listen⟶_QNodePP(n=1955, v=-16.160)
     │         │    └─── ₁tiger-right⟶_VNodePP(n=1441, v=-8.300)(depth=2)
     │         │         ├─── ₀listen⟶_QNodePP(n=947, v=-8.300)
     │         │         │    └─── ₁tiger-right⟶_VNodePP(n=768, v=0.022)(depth=3)
     │         │         │         ├─── ₀listen⟶_QNodePP(n=462, v=0.022)
     │         │         │         │    └─── ₁tiger-right⟶_VNodePP(n=395, v=10.000)(depth=4)
     │         │         │         │         ├─── ₁open-left⟶_QNodePP(n=247, v=10.000)
:code:`mbp` stands for "mark best plan".
To explore more features, browse the list of methods in the documentation.
"""
import sys
from pomdp_py.algorithms.po_uct import TreeNode, QNode, VNode, RootVNode
from pomdp_py.utils import typ, similar, special_char
# Minimum score returned by `similar` for a string key to be accepted as a
# match for a child edge (see `_NodePP.to_edge`).
SIMILAR_THRESH = 0.6
# Color name used by `TreeDebugger.mark_sequence` when no color is given.
DEFAULT_MARK_COLOR = "blue"
# Maps id(<original tree node>) -> coloring function for that node.
MARKED = {}  # tracks marked nodes on tree
def _node_pp(node, e=None, p=None, o=None):
    """Wrap a tree node in its pretty-printing counterpart.

    The wrapper is returned instead of the raw node so that, inside pdb,
    the node is displayed with this module's string formatting rather than
    its default one.

    Args:
        node: the node to wrap; VNodes become `_VNodePP`, anything else
            becomes `_QNodePP`.
        e: the edge leading from the parent to this node (parent_edge).
        p: the parent wrapper.
        o: the original (underlying) tree node, if already known.
    """
    wrapper_cls = _VNodePP if isinstance(node, VNode) else _QNodePP
    return wrapper_cls(node, parent_edge=e, parent=p, original=o)
class _NodePP:
    """Mixin that wraps a search-tree node for printing and navigation.

    A wrapper shares the ``children`` dict of the node it wraps and remembers
    that node as ``original``. Two wrappers compare (and hash) equal exactly
    when they wrap the same original node object.
    """

    def __init__(self, node, parent_edge=None, parent=None, original=None):
        """
        Args:
            node: either VNode or QNode (the actual node on the tree)
            parent_edge: the edge (key in the parent's children) that leads
                to this node, or None.
            parent: the wrapper of the parent node, or None.
            original: the underlying tree node; defaults to ``node``.
        """
        self.parent_edge = parent_edge
        self.parent = parent
        self.children = node.children
        # Whether str(self) should also list the children; switched off by
        # the tree printer, which prints children on their own lines.
        self.print_children = True
        self.original = node if original is None else original

    def __hash__(self):
        return id(self.original)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return id(self.original) == id(other.original)
        return False

    @property
    def marked(self):
        """True if the wrapped node has an entry in the module-level MARKED."""
        return id(self.original) in MARKED

    def to_edge(self, key):
        """Resolve ``key`` to an actual key of ``self.children``.

        See :meth:`__getitem__` for the accepted kinds of key.

        Raises:
            IndexError: if ``key`` is an integer outside the range of children.
            ValueError: if ``key`` cannot be resolved in any other way.
        """
        if key in self.children:
            return key
        if isinstance(key, int):
            # Same ordering as used when the children are listed for printing.
            return sorted_by_str(self.children.keys())[key]
        if isinstance(key, str) and self.children:
            # The emptiness guard above keeps max() from raising its own,
            # less helpful, ValueError on a node without children.
            chosen = max(self.children, key=lambda edge: similar(str(edge), key))
            if similar(str(chosen), key) >= SIMILAR_THRESH:
                return chosen
        raise ValueError("Cannot access children with key {}".format(key))

    def __getitem__(self, key):
        """
        When debugging, you can access the child of a node by the key
        of the following types:

        - the key is an action or observation object that points to a child;
          that is, key in self.children is True.
        - the key is an integer corresponding to the list of children shown
          when printing the node in the debugger
        - the key is a string that is similar to the string
          version of any of the action or observation edges;
          the most similar one will be chosen; The threshold
          of similarity is SIMILAR_THRESH

        Returns a new wrapper around the child, with ``self`` as its parent.
        """
        edge = self.to_edge(key)
        child = self.children[edge]
        original = child.original if isinstance(child, _NodePP) else None
        return _node_pp(child, e=edge, p=self, o=original)

    def __contains__(self, key):
        """True if ``key`` resolves to a child (see :meth:`to_edge`)."""
        try:
            self.to_edge(key)
        except (ValueError, IndexError):
            # IndexError: an integer key beyond the number of children.
            return False
        return True

    @staticmethod
    def interpret_print_type(opt):
        """Expand a (possibly abbreviated) print type by its first letter.

        "b..."/"m..." -> "marked-only", "s..." -> "summary",
        "c..." -> "complete".

        Raises:
            ValueError: if ``opt`` starts with any other letter.
        """
        if opt.startswith(("b", "m")):
            return "marked-only"
        if opt.startswith("s"):
            return "summary"
        if opt.startswith("c"):
            return "complete"
        raise ValueError("Cannot understand print type: {}".format(opt))

    def p(self, opt=None, **kwargs):
        """Print the tree rooted at this node.

        Args:
            opt: None (no depth limit), an int (maximum depth), or a str
                (print type, see :meth:`interpret_print_type`).
            **kwargs: ``t`` gives the print type when ``opt`` is None or an
                int (default "summary"); ``d`` gives the maximum depth when
                ``opt`` is a str (default None).

        Raises:
            ValueError: if ``opt`` is of any other type.
        """
        if opt is None or isinstance(opt, int):
            max_depth = opt
            print_type_opt = kwargs.get("t", "summary")
        elif isinstance(opt, str):
            print_type_opt = opt
            max_depth = kwargs.get("d", None)
        else:
            raise ValueError("Cannot deal with opt of type {}".format(type(opt)))
        self.print_tree(
            max_depth=max_depth, print_type=_NodePP.interpret_print_type(print_type_opt)
        )

    @property
    def pp(self):
        """Print the whole tree rooted at this node (no depth limit)."""
        self.print_tree(max_depth=None)

    def print_tree(self, **options):
        """Prints the tree, rooted at self"""
        _NodePP._print_tree_helper(self, 0, "", [None], -1, **options)

    @staticmethod
    def _print_tree_helper(
        root,
        depth,  # depth of root
        parent_edge,
        branch_positions,  # list of 'first', 'middle', 'last' for each level prior to root
        child_index,  # Index of the root as a child of parent
        max_depth=None,
        print_type="summary",
    ):
        """Recursively print ``root`` and the children selected by ``print_type``.

        Each entry of ``branch_positions`` is None (for the top of the printed
        tree) or one of 'first', 'middle', 'last'; the final entry describes
        ``root`` itself. ``parent_edge`` is accepted but not used here; the
        edge shown on each line comes from ``str(root)``.

        A child is printed if it is marked; otherwise it is printed only when
        ``print_type`` is not "marked-only" and either ``print_type`` is
        "complete" or the child has been visited more than once.
        """
        if max_depth is not None and depth > max_depth:
            return
        if root is None:
            return

        # Draw the branches for all levels above root, then root's own branch.
        pieces = []
        for pos in branch_positions[:-1]:
            if pos is None:
                continue
            pieces.append("│    " if pos in ("first", "middle") else "     ")
        last_position = branch_positions[-1]
        if last_position is not None:
            pieces.append("├─── " if last_position in ("first", "middle") else "└─── ")
        branches = "".join(pieces)

        # Children get their own lines below, so keep str(root) to one line.
        root.print_children = False
        if child_index >= 0:
            line = (
                branches
                + str(child_index).translate(special_char.SUBSCRIPT)
                + str(root)
            )
        else:
            line = branches + str(root)
        if isinstance(root, VNode):
            line += typ.cyan("(depth=" + str(depth) + ")")
        print(line)

        num_children = len(root.children)
        for i, c in enumerate(sorted_by_str(root.children)):
            # root[c] builds a new wrapper on every access; build it once.
            child = root[c]
            if child.marked:
                show = True
            elif print_type == "marked-only":
                show = False
            else:
                show = print_type == "complete" or child.num_visits > 1
            if not show:
                continue

            # A QNode sits on the same depth as the VNode above it.
            next_depth = depth if isinstance(child, QNode) else depth + 1
            # The position is taken among all children, printed or not.
            if i == num_children - 1:
                next_pos = "last"
            elif i == 0:
                next_pos = "first"
            else:
                next_pos = "middle"
            _NodePP._print_tree_helper(
                child,
                next_depth,
                c,
                branch_positions + [next_pos],
                i,
                max_depth=max_depth,
                print_type=print_type,
            )
class _QNodePP(_NodePP, QNode):
    """A QNode that prints itself using TreeDebugger's formatting."""

    def __init__(self, qnode, **kwargs):
        # Copy the statistics of the wrapped node, then attach the
        # wrapper bookkeeping (parent, parent edge, original, children).
        visits, value = qnode.num_visits, qnode.value
        QNode.__init__(self, visits, value)
        _NodePP.__init__(self, qnode, **kwargs)

    def __str__(self):
        with_children = self.print_children
        edge = self.parent_edge
        return TreeDebugger.single_node_str(
            self, parent_edge=edge, include_children=with_children
        )
class _VNodePP(_NodePP, VNode):
    """A VNode that prints itself using TreeDebugger's formatting."""

    def __init__(self, vnode, **kwargs):
        # Copy the visit count of the wrapped node, then attach the
        # wrapper bookkeeping (parent, parent edge, original, children).
        visits = vnode.num_visits
        VNode.__init__(self, visits)
        _NodePP.__init__(self, vnode, **kwargs)

    def __str__(self):
        with_children = self.print_children
        edge = self.parent_edge
        return TreeDebugger.single_node_str(
            self, parent_edge=edge, include_children=with_children
        )
class TreeDebugger:
    """
    Helps you debug the search tree; A search tree is a tree
    that contains a subset of future histories, organized into
    QNodes (value represents Q(b,a); children are observations) and
    VNodes (value represents V(b); children are actions).
    """

    def __init__(self, tree):
        """
        Args:
            tree (VNode): the root node of a search tree. For example,
                the tree built by POUCT after planning an action,
                which can be accessed by agent.tree.

        Raises:
            ValueError: if ``tree`` is not a TreeNode.
        """
        if not isinstance(tree, TreeNode):
            raise ValueError(
                "Expecting tree to be a TreeNode, but got {}".format(type(tree))
            )
        self.tree = _node_pp(tree)
        self.current = self.tree  # points to the node the user is interacting with
        self._stats_cache = {}

    def __str__(self):
        return str(self.current)

    def __repr__(self):
        nodestr = TreeDebugger.single_node_str(
            self.current, parent_edge=self.current.parent_edge
        )
        return "TreeDebugger@\n{}".format(nodestr)

    def __getitem__(self, key):
        """Return the child of the current node at ``key``.

        A tuple is followed as a path of keys starting from the current node.
        """
        if isinstance(key, tuple):
            node = self.current
            for k in key:
                node = node[k]
            return node
        return self.current[key]

    def _get_stats(self):
        """Return (and cache) the statistics of the tree rooted at "current"."""
        # Key on the underlying node: wrappers are re-created on every child
        # access, so their ids neither persist nor stay unique.
        cache_key = id(self.current.original)
        if cache_key not in self._stats_cache:
            self._stats_cache[cache_key] = TreeDebugger.tree_stats(self.current)
        return self._stats_cache[cache_key]

    def num_nodes(self, kind="all"):
        """
        Returns the total number of nodes in the tree rooted at "current"

        Args:
            kind (str): "all", "q" (QNodes only) or "v" (VNodes only).

        Raises:
            ValueError: if ``kind`` is none of the above.
        """
        stats = self._get_stats()
        res = {
            "all": stats["total_vnodes"] + stats["total_qnodes"],
            "q": stats["total_qnodes"],
            "v": stats["total_vnodes"],
        }
        if kind in res:
            return res[kind]
        raise ValueError(
            "Invalid value for kind={}; Valid values are {}".format(
                kind, list(res.keys())
            )
        )

    @property
    def depth(self):
        """Tree depth starts from 0 (root node only).
        It is the largest number of edges on a path from root to leaf."""
        return self._get_stats()["max_depth"]

    @property
    def d(self):
        """alias for depth"""
        return self.depth

    @property
    def num_layers(self):
        """Returns the number of layers;
        It is the number of layers of nodes, which equals to depth + 1"""
        return self.depth + 1

    @property
    def nl(self):
        """alias for num_layers"""
        return self.num_layers

    @property
    def nn(self):
        """Returns the total number of nodes in the tree"""
        return self.num_nodes(kind="all")

    @property
    def nq(self):
        """Returns the total number of QNodes in the tree"""
        return self.num_nodes(kind="q")

    @property
    def nv(self):
        """Returns the total number of VNodes in the tree"""
        return self.num_nodes(kind="v")

    def l(self, depth, as_debuggers=True):
        """alias for layer"""
        return self.layer(depth, as_debuggers=as_debuggers)

    def layer(self, depth, as_debuggers=True):
        """
        Returns a list of nodes at the given depth. Will only return VNodes.
        Warning: If depth is high, there will likely be a huge number of nodes.

        Args:
            depth (int): Depth of the tree
            as_debuggers (bool): True if return a list of TreeDebugger objects,
                one for each tree on the layer. Otherwise the node wrappers
                themselves are returned.

        Raises:
            ValueError: if ``depth`` is negative or larger than ``self.depth``.
        """
        if depth < 0 or depth > self.depth:
            raise ValueError(
                "Depth {} is out of range (0-{})".format(depth, self.depth)
            )
        nodes = []
        self._layer_helper(self.current, 0, depth, nodes, as_debuggers=as_debuggers)
        return nodes

    def _layer_helper(
        self, root, current_depth, target_depth, nodes, as_debuggers=True
    ):
        """Collect into ``nodes`` the VNodes found at ``target_depth``."""
        if current_depth == target_depth:
            if isinstance(root, VNode):
                if as_debuggers:
                    nodes.append(TreeDebugger(root.original))
                else:
                    nodes.append(root)
            return
        for c in sorted_by_str(root.children):
            child = root[c]
            # A QNode sits on the same depth as the VNode above it.
            next_depth = current_depth if isinstance(child, QNode) else current_depth + 1
            self._layer_helper(
                child, next_depth, target_depth, nodes, as_debuggers=as_debuggers
            )

    @property
    def leaf(self):
        """List of the nodes without children reachable from "current"."""
        worklist = [self.current]
        seen = {self.current}
        leafs = []
        while worklist:
            node = worklist.pop()
            if len(node.children) == 0:
                leafs.append(node)
                continue
            for c in node.children:
                child = node[c]
                if child not in seen:
                    worklist.append(child)
                    seen.add(child)
        return leafs

    def step(self, key):
        """Updates current interaction node to follow the
        edge along key"""
        edge = self.current.to_edge(key)
        # Index the node directly: self[edge] would treat a tuple-valued edge
        # as a path of several keys.
        self.current = self.current[edge]
        print("step: " + str(edge))

    def s(self, key):
        """alias for step"""
        return self.step(key)

    def back(self):
        """move current node of interaction back to parent;
        does nothing (besides printing a notice) if there is no parent"""
        parent = self.current.parent
        if parent is None:
            # Without this guard "current" would become None and every
            # later operation on the debugger would fail.
            print("back: current node has no parent")
            return
        self.current = parent

    @property
    def b(self):
        """alias for back"""
        self.back()

    @property
    def root(self):
        """The root node when first creating this TreeDebugger"""
        return self.tree

    @property
    def r(self):
        """alias for root"""
        return self.root

    @property
    def c(self):
        """Current node of interaction"""
        return self.current

    def p(self, *args, **kwargs):
        """print tree"""
        return self.current.p(*args, **kwargs)

    @property
    def pp(self):
        """print tree, with preset options"""
        return self.current.pp

    @property
    def mbp(self):
        """Mark Best and Print.
        Mark the best sequence, and then print with only the marked nodes"""
        self.mark(self.bestseq, color="yellow")
        self.p("marked-only")

    @property
    def pm(self):
        """Print marked only"""
        self.p("marked-only")

    def mark_sequence(self, seq, color=DEFAULT_MARK_COLOR):
        """
        Given a list of keys (understandable by __getitem__ in _NodePP),
        mark nodes (both QNode and VNode) along the path in the tree.
        Note this sequence starts from self.current; So self.current will
        also be marked.
        """
        color_fn = interpret_color(color)
        node = self.current
        MARKED[id(node.original)] = color_fn
        for key in seq:
            node = node[key]
            MARKED[id(node.original)] = color_fn

    def mark(self, seq, **kwargs):
        """alias for mark_sequence"""
        return self.mark_sequence(seq, **kwargs)

    def mark_path(self, dest, **kwargs):
        """marks the path from self.current to the dest node

        Raises:
            ValueError: if dest cannot be reached from self.current.
        """
        path = self.path_to(dest)
        if path is None:
            raise ValueError(
                "Cannot mark path: destination is not reachable from the current node"
            )
        return self.mark(path, **kwargs)

    def markp(self, dest, **kwargs):
        """alias to mark_path"""
        return self.mark_path(dest, **kwargs)

    @property
    def clear(self):
        """Clear the marks"""
        # Empty the dict in place so that every reference to it sees the change.
        MARKED.clear()

    @property
    def bestseq(self):
        """Returns a list of actions, observation sequence
        that have the highest value for each step. Such
        a sequence is "preferred".
        Also, prints out the list of preferred actions for each step
        into the future"""
        return self.preferred_actions(self.current, max_depth=None)

    def bestseqd(self, max_depth):
        """
        alias for bestseq except with a limit on the depth (max_depth)
        """
        return self.preferred_actions(self.current, max_depth=max_depth)

    @staticmethod
    def single_node_str(node, parent_edge=None, indent=1, include_children=True):
        """
        Returns a string for printing given a single node.

        Args:
            node: a VNode or QNode (wrapped or not).
            parent_edge: if not None, it is printed in front of the node,
                followed by an arrow.
            indent (int): number of 4-space units put before each child line.
            include_children (bool): True to append one line per child,
                numbered in the order given by sorted_by_str.
        """
        if getattr(node, "marked", False):
            color_fn = MARKED[id(node.original)]

            def color(s):
                return typ.bold(color_fn(s))

            opposite_color = color
        elif isinstance(node, VNode):
            color = typ.green
            opposite_color = typ.red
        else:
            assert isinstance(node, QNode)
            color = typ.red
            opposite_color = typ.green

        output = ""
        if parent_edge is not None:
            output += opposite_color(str(parent_edge)) + "⟶"
        output += color(str(node.__class__.__name__)) + "(n={}, v={:.3f})".format(
            node.num_visits, node.value
        )
        if include_children:
            spaces = "    " * indent
            child_lines = []
            for i, action in enumerate(sorted_by_str(node.children)):
                child = node.children[action]
                child_info = TreeDebugger.single_node_str(child, include_children=False)
                child_lines.append(
                    "{}- [{}] {}: {}".format(
                        spaces, i, typ.white(str(action)), child_info
                    )
                )
            output += "\n" + "\n".join(child_lines) + "\n"
        return output

    @staticmethod
    def preferred_actions(root, max_depth=None):
        """
        Print out the currently preferred actions up to given `max_depth`,
        and return the list of preferred edges (actions and observations).
        """
        seq = []
        TreeDebugger._preferred_actions_helper(root, 0, seq, max_depth=max_depth)
        return seq

    @staticmethod
    def _preferred_actions_helper(root, depth, seq, max_depth=None):
        """Append to ``seq`` the highest-valued edge at each step below ``root``."""
        # don't care about last layer action because it's outside of planning
        # horizon and only has initial value.
        if max_depth is not None and depth > max_depth:
            return
        if root is None or len(root.children) == 0:
            return

        # Start from the first child (in printing order) and keep whichever
        # child has a strictly higher value.
        best_child = root.to_edge(0)
        best_value = root[0].value
        for c in root.children:
            value = root[c].value
            if value > best_value:
                best_child = c
                best_value = value
        seq.append(best_child)

        equally_good = []
        if isinstance(root, VNode):
            equally_good = [
                c
                for c in root.children
                if not (c == best_child) and root[c].value == best_value
            ]

        if best_child is None:
            return
        next_node = root[best_child]
        if next_node is None:
            return
        if isinstance(next_node, QNode):
            # best_child is an action here; print it with its ties.
            print("  %s  %s" % (typ.yellow(str(best_child)), str(equally_good)))
            next_depth = depth
        else:
            next_depth = depth + 1
        TreeDebugger._preferred_actions_helper(
            next_node, next_depth, seq, max_depth=max_depth
        )

    def path(self, dest):
        """alias for path_to;
        Example usage:
        marking path from root to the first node on the second layer:
            dd.mark(dd.path(dd.layer(2)[0]))
        """
        return self.path_to(dest)

    def path_to(self, dest):
        """Returns a list of keys (actions / observations) that represents the path from
        self.current to the given node `dest`. Returns None if the path does not
        exist.  Uses DFS. Can be useful for marking path to a node to a specific
        layer. Note that the returned path is a list of keys (i.e. edges), not nodes.
        """
        # dest may be in the returned list of layer() which could be a TreeDebugger.
        if isinstance(dest, TreeDebugger):
            dest = dest.current
        worklist = [self.current]
        seen = {self.current}
        parent = {self.current: None}
        while worklist:
            node = worklist.pop()
            if node == dest:
                return self._get_path(self.current, dest, parent)
            for c in node.children:
                child = node[c]
                if child not in seen:
                    worklist.append(child)
                    seen.add(child)
                    parent[child] = (node, c)
        return None

    def _get_path(self, start, dest, parent):
        """Helper method for path_to"""
        v = dest
        path = []
        while v != start:
            v, edge = parent[v]
            path.append(edge)
        path.reverse()
        return path

    @staticmethod
    def tree_stats(root, max_depth=None):
        """Gather statistics about the tree rooted at ``root``.

        Returns a dict with the number of VNodes and QNodes, the total and
        maximum number of children for each kind, the maximum depth reached,
        and the ``num_visits`` and ``value`` of ``root``.
        """
        stats = {
            "total_vnodes": 0,
            "total_qnodes": 0,
            "total_vnodes_children": 0,
            "total_qnodes_children": 0,
            "max_vnodes_children": 0,
            "max_qnodes_children": 0,
            "max_depth": 0,
        }
        TreeDebugger._tree_stats_helper(root, 0, stats, max_depth=max_depth)
        stats["num_visits"] = root.num_visits
        stats["value"] = root.value
        return stats

    @staticmethod
    def _tree_stats_helper(root, depth, stats, max_depth=None):
        """Accumulate into ``stats`` the counts for ``root`` and its descendants."""
        if max_depth is not None and depth > max_depth:
            return
        num_children = len(root.children)
        if isinstance(root, VNode):
            stats["total_vnodes"] += 1
            stats["total_vnodes_children"] += num_children
            stats["max_vnodes_children"] = max(
                stats["max_vnodes_children"], num_children
            )
            stats["max_depth"] = max(stats["max_depth"], depth)
        else:
            stats["total_qnodes"] += 1
            stats["total_qnodes_children"] += num_children
            stats["max_qnodes_children"] = max(
                stats["max_qnodes_children"], num_children
            )
        for c in root.children:
            child = root[c]
            # A QNode sits on the same depth as the VNode above it.
            next_depth = depth if isinstance(child, QNode) else depth + 1
            TreeDebugger._tree_stats_helper(
                child, next_depth, stats, max_depth=max_depth
            )
def sorted_by_str(enumerable):
    """Return the items of ``enumerable`` as a list sorted by ``str(item)``.

    Gives edges (actions / observations) a stable order even when the
    objects themselves cannot be compared with each other.
    """
    return sorted(enumerable, key=str)
def interpret_color(colorstr):
    """Return the coloring function of ``typ`` named by ``colorstr``.

    The name is matched case-insensitively against ``typ.colors``.
    Assumes ``typ`` exposes a function under each (lowercase) name listed
    in ``typ.colors`` -- confirm against ``pomdp_py.utils.typ``.

    Raises:
        ValueError: if ``colorstr`` is not one of the available colors.
    """
    name = colorstr.lower()
    if name in typ.colors:
        # Look up the validated, lower-cased name; the membership test is
        # case-insensitive, so the lookup must be as well.
        return getattr(typ, name)
    raise ValueError(
        "Invalid color: {};\nThe available ones are {}".format(colorstr, typ.colors)
    )