"""This module contains utility functions that make it easier to debug POMDP
planning.

TreeDebugger
************

The core debugging functionality for POMCP/POUCT search trees is incorporated
into the TreeDebugger. It is designed for ease of use during a :code:`pdb` or
:code:`ipdb` debugging session. Here is a minimal example usage:

.. code-block:: python

    import pomdp_py
    from pomdp_py.utils import TreeDebugger
    from pomdp_problems.tiger import TigerProblem

    # pomdp_py.Agent
    agent = TigerProblem.create("tiger-left", 0.5, 0.15).agent

    # suppose pouct is a pomdp_py.POUCT object (POMCP works too)
    pouct = pomdp_py.POUCT(max_depth=4, discount_factor=0.95,
                           num_sims=4096, exploration_const=200,
                           rollout_policy=agent.policy_model)

    action = pouct.plan(agent)
    dd = TreeDebugger(agent.tree)
    import pdb; pdb.set_trace()

When the program executes, you enter the pdb debugger, and you can:

.. code-block:: text

    (Pdb) dd.pp
    _VNodePP(n=4095, v=-19.529)(depth=0)
    ├─── ₀listen⟶_QNodePP(n=4059, v=-19.529)
    │    ├─── ₀tiger-left⟶_VNodePP(n=2013, v=-16.586)(depth=1)
    │    │    ├─── ₀listen⟶_QNodePP(n=1883, v=-16.586)
    │    │    │    ├─── ₀tiger-left⟶_VNodePP(n=1441, v=-8.300)(depth=2)
    ... # prints out the entire tree; Colored in terminal.

    (Pdb) dd.p(1)
    _VNodePP(n=4095, v=-19.529)(depth=0)
    ├─── ₀listen⟶_QNodePP(n=4059, v=-19.529)
    │    ├─── ₀tiger-left⟶_VNodePP(n=2013, v=-16.586)(depth=1)
    │    │    ├─── ₀listen⟶_QNodePP(n=1883, v=-16.586)
    │    │    ├─── ₁open-left⟶_QNodePP(n=18, v=-139.847)
    │    │    └─── ₂open-right⟶_QNodePP(n=112, v=-57.191)
    ... # prints up to depth 1

Note that the printed text is colored in the terminal.

You can retrieve a subtree by indexing:

.. code-block:: text

    (Pdb) dd[0]
    listen⟶_QNodePP(n=4059, v=-19.529)
     - [0] tiger-left: VNode(n=2013, v=-16.586)
     - [1] tiger-right: VNode(n=2044, v=-16.160)

    (Pdb) dd[0][1][2]
    open-right⟶_QNodePP(n=15, v=-148.634)
     - [0] tiger-left: VNode(n=7, v=-20.237)
     - [1] tiger-right: VNode(n=6, v=8.500)

You can obtain the currently preferred action sequence by:

.. code-block:: text

    (Pdb) dd.mbp
     listen []
     listen []
     listen []
     listen []
     open-left []
    _VNodePP(n=4095, v=-19.529)(depth=0)
    ├─── ₀listen⟶_QNodePP(n=4059, v=-19.529)
    │    └─── ₁tiger-right⟶_VNodePP(n=2044, v=-16.160)(depth=1)
    │         ├─── ₀listen⟶_QNodePP(n=1955, v=-16.160)
    │         │    └─── ₁tiger-right⟶_VNodePP(n=1441, v=-8.300)(depth=2)
    │         │         ├─── ₀listen⟶_QNodePP(n=947, v=-8.300)
    │         │         │    └─── ₁tiger-right⟶_VNodePP(n=768, v=0.022)(depth=3)
    │         │         │         ├─── ₀listen⟶_QNodePP(n=462, v=0.022)
    │         │         │         │    └─── ₁tiger-right⟶_VNodePP(n=395, v=10.000)(depth=4)
    │         │         │         │         ├─── ₁open-left⟶_QNodePP(n=247, v=10.000)
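
:code:`mbp` stands for "mark best and print": it marks the currently
preferred sequence of actions and observations, then prints only the
marked nodes.

To explore more features, browse the list of methods in the documentation.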
"""
import sys
from pomdp_py.algorithms.po_uct import TreeNode, QNode, VNode, RootVNode
from pomdp_py.utils import typ, similar, special_char
# Minimum score returned by ``similar`` for a string key to select a child
# edge by approximate name (see _NodePP.to_edge).
SIMILAR_THRESH: float = 0.6
# Color used by TreeDebugger.mark_sequence when no color is given.
DEFAULT_MARK_COLOR: str = "blue"
# Nodes marked on the tree: id(underlying tree node) -> coloring function.
MARKED: dict = {}
def _node_pp(node, e=None, p=None, o=None):
    """Wrap ``node`` in the pretty-printing class that matches its kind.

    A wrapper is returned instead of the node itself so that a debugger such
    as pdb shows the node with this module's formatting rather than with the
    node's default string.

    Args:
        node: tree node to wrap; a VNode becomes a _VNodePP, anything else
            becomes a _QNodePP.
        e: edge leading to ``node`` from its parent (stored as parent_edge).
        p: wrapper of the parent node (stored as parent).
        o: the underlying tree node, when ``node`` is itself a wrapper.
    """
    wrapper_cls = _VNodePP if isinstance(node, VNode) else _QNodePP
    return wrapper_cls(node, parent_edge=e, parent=p, original=o)
class _NodePP:
    """Common base of the pretty-printing node wrappers (_QNodePP, _VNodePP).

    A wrapper shares the ``children`` dict of the tree node it wraps and
    records how it was reached (``parent_edge``, ``parent``), so a search
    tree can be indexed, printed and marked from a debugger session. Its
    identity (hash, equality, marks) is that of ``original``, the underlying
    tree node.
    """

    def __init__(self, node, parent_edge=None, parent=None, original=None):
        """
        Args:
            node: either VNode or QNode (the actual node on the tree).
            parent_edge: edge (action or observation) from ``parent`` to this
                node; None when the node was not reached through an edge.
            parent (_NodePP): wrapper of the parent node, or None.
            original: the underlying tree node; defaults to ``node``.
        """
        self.parent_edge = parent_edge
        self.parent = parent
        self.children = node.children
        # When False, __str__ shows only this node, without its children list.
        self.print_children = True
        self.original = node if original is None else original

    def __hash__(self):
        return id(self.original)

    def __eq__(self, other):
        # Wrappers of the same class around the very same tree node are equal.
        if isinstance(other, self.__class__):
            return id(self.original) == id(other.original)
        return False

    @property
    def marked(self):
        """True if the underlying tree node has an entry in MARKED."""
        return id(self.original) in MARKED

    def to_edge(self, key):
        """Resolve ``key`` into one of this node's child edges.

        ``key`` may be an edge itself, an integer index into the children
        sorted by their string form, or a string whose similarity to an
        edge's string form is at least SIMILAR_THRESH.

        Raises:
            ValueError: if no child edge matches ``key``.
            IndexError: if ``key`` is an int outside the range of children.
        """
        if key in self.children:
            return key
        elif type(key) == int:
            edges = list(sorted_by_str(self.children.keys()))
            return edges[key]
        elif type(key) == str and len(self.children) > 0:
            # Guarded: max() over no children would raise an unrelated error.
            chosen = max(
                self.children.keys(), key=lambda edge: similar(str(edge), key)
            )
            if similar(str(chosen), key) >= SIMILAR_THRESH:
                return chosen
        raise ValueError("Cannot access children with key {}".format(key))

    def __getitem__(self, key):
        """
        When debugging, you can access the child of a node by the key
        of the following types:
        - the key is an action or observation object that points to a child;
          that is, key in self.children is True.
        - the key is an integer corresponding to the list of children shown
          when printing the node in the debugger
        - the key is a string that is similar to the string
          version of any of the action or observation edges;
          the most similar one will be chosen; The threshold
          of similarity is SIMILAR_THRESH
        """
        edge = self.to_edge(key)
        child = self.children[edge]
        # Keep pointing at the real tree node if the stored child is a wrapper.
        original = child.original if isinstance(child, _NodePP) else None
        return _node_pp(child, e=edge, p=self, o=original)

    def __contains__(self, key):
        """True if ``key`` resolves to a child edge (see to_edge)."""
        try:
            self.to_edge(key)
        except (ValueError, IndexError):
            # IndexError: an integer index past the number of children.
            return False
        return True

    @staticmethod
    def interpret_print_type(opt):
        """Map a print-type abbreviation to its canonical name.

        Strings starting with "b" or "m" -> "marked-only", "s" -> "summary",
        "c" -> "complete".

        Raises:
            ValueError: for any other prefix.
        """
        if opt.startswith(("b", "m")):
            return "marked-only"
        elif opt.startswith("s"):
            return "summary"
        elif opt.startswith("c"):
            return "complete"
        raise ValueError("Cannot understand print type: {}".format(opt))

    def p(self, opt=None, **kwargs):
        """Print the tree rooted at this node.

        Args:
            opt: None (no depth limit), an int (maximum depth), or a string
                print type (see interpret_print_type).
            t: print type, used when ``opt`` is None or an int.
            d: maximum depth, used when ``opt`` is a string.

        Raises:
            ValueError: if ``opt`` has any other type.
        """
        if opt is None:
            max_depth = None
            print_type_opt = kwargs.get("t", "summary")
        elif type(opt) == int:
            max_depth = opt
            print_type_opt = kwargs.get("t", "summary")
        elif type(opt) == str:
            print_type_opt = opt
            max_depth = kwargs.get("d", None)
        else:
            raise ValueError("Cannot deal with opt of type {}".format(type(opt)))
        self.print_tree(
            max_depth=max_depth, print_type=_NodePP.interpret_print_type(print_type_opt)
        )

    @property
    def pp(self):
        """Print the whole tree rooted here with the default (summary) options."""
        self.print_tree(max_depth=None)

    def print_tree(self, **options):
        """Prints the tree, rooted at self"""
        _NodePP._print_tree_helper(self, 0, "", [None], -1, **options)

    @staticmethod
    def _print_tree_helper(
        root,
        depth,  # depth of root
        parent_edge,
        branch_positions,  # 'first'/'middle'/'last' (or None) for each level up to root
        child_index,  # Index of the root as a child of parent
        max_depth=None,
        print_type="summary",
    ):
        """Recursively print ``root`` and the descendants selected by print_type.

        ``branch_positions[-1]`` is root's position among the children that
        are actually printed under its parent ('first', 'middle', 'last'), or
        None for the top of the printout. ``child_index`` is root's index
        among all of its parent's children in sorted_by_str order (the index
        accepted by ``__getitem__``), or -1 for the top.
        """
        if max_depth is not None and depth > max_depth:
            return
        if root is None:
            return

        # Vertical guides for every ancestor level. Each piece is as wide as
        # the "├─── " connector so that nested levels line up.
        branches = ""
        for pos in branch_positions[:-1]:  # all positions except for current root
            if pos is None:
                continue
            elif pos == "first" or pos == "middle":
                branches += "│    "
            else:  # "last"
                branches += "     "
        last_position = branch_positions[-1]
        if last_position == "first" or last_position == "middle":
            branches += "├─── "
        elif last_position is not None:  # "last"
            branches += "└─── "

        # Render the node on one line (no children listing), then restore the
        # flag so printing does not permanently change the node's str().
        saved_print_children = root.print_children
        root.print_children = False
        try:
            node_str = str(root)
        finally:
            root.print_children = saved_print_children
        if child_index >= 0:
            line = (
                branches
                + str(child_index).translate(special_char.SUBSCRIPT)
                + node_str
            )
        else:
            line = branches + node_str
        if isinstance(root, VNode):
            line += typ.cyan("(depth=" + str(depth) + ")")
        print(line)

        # Select the children to show first, so that the last *shown* child
        # gets the closing "└───" connector even when later ones are skipped.
        shown = []
        for i, edge in enumerate(sorted_by_str(root.children)):
            child = root[edge]
            if print_type == "marked-only":
                show = child.marked
            else:
                show = (
                    child.marked
                    or print_type == "complete"
                    or child.num_visits > 1
                )
            if show:
                shown.append((i, edge, child))

        for k, (i, edge, child) in enumerate(shown):
            # A QNode shares its parent VNode's depth; a VNode is one deeper.
            next_depth = depth if isinstance(child, QNode) else depth + 1
            if k == len(shown) - 1:
                next_pos = "last"
            elif k == 0:
                next_pos = "first"
            else:
                next_pos = "middle"
            _NodePP._print_tree_helper(
                child,
                next_depth,
                edge,
                branch_positions + [next_pos],
                i,
                max_depth=max_depth,
                print_type=print_type,
            )
class _QNodePP(_NodePP, QNode):
    """Pretty-printing wrapper of a QNode (value represents Q(b,a))."""

    def __init__(self, qnode, **kwargs):
        # Copy the visit count and value into the QNode part first;
        # _NodePP.__init__ then sets ``children`` to the wrapped node's dict.
        QNode.__init__(self, qnode.num_visits, qnode.value)
        _NodePP.__init__(self, qnode, **kwargs)

    def __str__(self):
        show_children = self.print_children
        edge = self.parent_edge
        return TreeDebugger.single_node_str(
            self, parent_edge=edge, include_children=show_children
        )
class _VNodePP(_NodePP, VNode):
    """Pretty-printing wrapper of a VNode (value represents V(b))."""

    def __init__(self, vnode, **kwargs):
        # Copy the visit count into the VNode part first;
        # _NodePP.__init__ then sets ``children`` to the wrapped node's dict.
        VNode.__init__(self, vnode.num_visits)
        _NodePP.__init__(self, vnode, **kwargs)

    def __str__(self):
        show_children = self.print_children
        edge = self.parent_edge
        return TreeDebugger.single_node_str(
            self, parent_edge=edge, include_children=show_children
        )
class TreeDebugger:
    """
    Helps you debug the search tree; A search tree is a tree
    that contains a subset of future histories, organized into
    QNodes (value represents Q(b,a); children are observations) and
    VNodes (value represents V(b); children are actions).
    """

    def __init__(self, tree):
        """
        Args:
            tree (VNode): the root node of a search tree. For example,
                the tree built by POUCT after planning an action,
                which can be accessed by agent.tree.

        Raises:
            ValueError: if ``tree`` is not a TreeNode.
        """
        if not isinstance(tree, TreeNode):
            raise ValueError(
                "Expecting tree to be a TreeNode, but got {}".format(type(tree))
            )
        self.tree = _node_pp(tree)
        self.current = self.tree  # points to the node the user is interacting with
        # tree_stats() results, keyed by id() of the underlying tree node.
        self._stats_cache = {}

    def __str__(self):
        return str(self.current)

    def __repr__(self):
        nodestr = TreeDebugger.single_node_str(
            self.current, parent_edge=self.current.parent_edge
        )
        return "TreeDebugger@\n{}".format(nodestr)

    def __getitem__(self, key):
        """Index into the current node; a tuple key is followed as a path,
        e.g. ``dd[0, 1]`` is ``dd[0][1]``."""
        if type(key) == tuple:
            n = self.current
            for k in key:
                n = n[k]
            return n
        else:
            return self.current[key]

    def _get_stats(self):
        """Return (cached) tree_stats for the subtree rooted at current."""
        # Keyed by the underlying tree node, not the wrapper: wrappers are
        # re-created by step()/indexing, and a freed wrapper's id() may be
        # reused by a different node, which would return stale stats.
        key = id(self.current.original)
        if key not in self._stats_cache:
            self._stats_cache[key] = TreeDebugger.tree_stats(self.current)
        return self._stats_cache[key]

    def num_nodes(self, kind="all"):
        """
        Returns the total number of nodes in the tree rooted at "current"

        Args:
            kind (str): "all", "q" (QNodes only) or "v" (VNodes only).

        Raises:
            ValueError: for any other ``kind``.
        """
        stats = self._get_stats()
        res = {
            "all": stats["total_vnodes"] + stats["total_qnodes"],
            "q": stats["total_qnodes"],
            "v": stats["total_vnodes"],
        }
        if kind in res:
            return res[kind]
        else:
            raise ValueError(
                "Invalid value for kind={}; Valid values are {}".format(
                    kind, list(res.keys())
                )
            )

    @property
    def depth(self):
        """Tree depth starts from 0 (root node only).
        It is the largest number of edges on a path from root to leaf."""
        stats = self._get_stats()
        return stats["max_depth"]

    @property
    def d(self):
        """alias for depth"""
        return self.depth

    @property
    def num_layers(self):
        """Returns the number of layers;
        It is the number of layers of nodes, which equals to depth + 1"""
        return self.depth + 1

    @property
    def nl(self):
        """alias for num_layers"""
        return self.num_layers

    @property
    def nn(self):
        """Returns the total number of nodes in the tree"""
        return self.num_nodes(kind="all")

    @property
    def nq(self):
        """Returns the total number of QNodes in the tree"""
        return self.num_nodes(kind="q")

    @property
    def nv(self):
        """Returns the total number of VNodes in the tree"""
        return self.num_nodes(kind="v")

    def l(self, depth, as_debuggers=True):
        """alias for layer"""
        return self.layer(depth, as_debuggers=as_debuggers)

    def layer(self, depth, as_debuggers=True):
        """
        Returns a list of nodes at the given depth. Will only return VNodes.
        Warning: If depth is high, there will likely be a huge number of nodes.

        Args:
            depth (int): Depth of the tree
            as_debuggers (bool): True if return a list of TreeDebugger objects,
                one for each tree on the layer; False to return the node
                wrappers themselves.

        Raises:
            ValueError: if depth is outside 0..self.depth.
        """
        if depth < 0 or depth > self.depth:
            raise ValueError(
                "Depth {} is out of range (0-{})".format(depth, self.depth)
            )
        nodes = []
        self._layer_helper(
            self.current, 0, depth, nodes, as_debuggers=as_debuggers
        )
        return nodes

    def _layer_helper(
        self, root, current_depth, target_depth, nodes, as_debuggers=True
    ):
        """Depth-first search collecting the VNodes at target_depth into nodes."""
        if current_depth == target_depth:
            if isinstance(root, VNode):
                if as_debuggers:
                    nodes.append(TreeDebugger(root.original))
                else:
                    nodes.append(root)
        else:
            for c in sorted_by_str(root.children):
                child = root[c]
                # A QNode shares its parent VNode's depth; a VNode is one deeper.
                if isinstance(child, QNode):
                    next_depth = current_depth
                else:
                    next_depth = current_depth + 1
                self._layer_helper(
                    child, next_depth, target_depth, nodes, as_debuggers=as_debuggers
                )

    @property
    def leaf(self):
        """List of nodes without children in the subtree rooted at current."""
        worklist = [self.current]
        seen = set({self.current})
        leafs = []
        while len(worklist) > 0:
            node = worklist.pop()
            if len(node.children) == 0:
                leafs.append(node)
            else:
                for c in node.children:
                    child = node[c]
                    if child not in seen:
                        worklist.append(child)
                        seen.add(child)
        return leafs

    def step(self, key):
        """Updates current interaction node to follow the
        edge along key"""
        edge = self.current.to_edge(key)
        # Index the node directly: self[edge] would treat a tuple-valued
        # edge as a multi-step path.
        self.current = self.current[edge]
        print("step: " + str(edge))

    def s(self, key):
        """alias for step"""
        return self.step(key)

    def back(self):
        """move current node of interaction back to parent; stays put (and
        prints a notice) if the current node has no parent"""
        if self.current.parent is None:
            print("back: current node has no parent")
            return
        self.current = self.current.parent

    @property
    def b(self):
        """alias for back"""
        self.back()

    @property
    def root(self):
        """The root node when first creating this TreeDebugger"""
        return self.tree

    @property
    def r(self):
        """alias for root"""
        return self.root

    @property
    def c(self):
        """Current node of interaction"""
        return self.current

    def p(self, *args, **kwargs):
        """print tree"""
        return self.current.p(*args, **kwargs)

    @property
    def pp(self):
        """print tree, with preset options"""
        return self.current.pp

    @property
    def mbp(self):
        """Mark Best and Print.
        Mark the best sequence, and then print with only the marked nodes"""
        self.mark(self.bestseq, color="yellow")
        self.p("marked-only")

    @property
    def pm(self):
        """Print marked only"""
        self.p("marked-only")

    def mark_sequence(self, seq, color=DEFAULT_MARK_COLOR):
        """
        Given a list of keys (understandable by __getitem__ in _NodePP),
        mark nodes (both QNode and VNode) along the path in the tree.
        Note this sequence starts from self.current; So self.current will
        also be marked.

        Raises:
            ValueError: if ``color`` is not a valid color name.
        """
        color_fn = interpret_color(color)
        node = self.current
        MARKED[id(node.original)] = color_fn
        for key in seq:
            node = node[key]
            MARKED[id(node.original)] = color_fn

    def mark(self, seq, **kwargs):
        """alias for mark_sequence"""
        return self.mark_sequence(seq, **kwargs)

    def mark_path(self, dest, **kwargs):
        """marks the path from the current node to dest node"""
        return self.mark(self.path_to(dest), **kwargs)

    def markp(self, dest, **kwargs):
        """alias to mark_path"""
        return self.mark_path(dest, **kwargs)

    @property
    def clear(self):
        """Clear the marks"""
        # Empty the shared dict in place instead of rebinding the global.
        MARKED.clear()

    @property
    def bestseq(self):
        """Returns a list of actions, observation sequence
        that have the highest value for each step. Such
        a sequence is "preferred".
        Also, prints out the list of preferred actions for each step
        into the future"""
        return self.preferred_actions(self.current, max_depth=None)

    def bestseqd(self, max_depth):
        """
        alias for bestseq except with a maximum depth ``max_depth``
        """
        return self.preferred_actions(self.current, max_depth=max_depth)

    @staticmethod
    def single_node_str(node, parent_edge=None, indent=1, include_children=True):
        """
        Returns a string for printing given a single node (VNode or QNode),
        optionally followed by one line per child.
        """
        if hasattr(node, "marked") and node.marked:
            color_fn = MARKED[id(node.original)]
            opposite_color = color = lambda s: typ.bold(color_fn(s))
        elif isinstance(node, VNode):
            color = typ.green
            opposite_color = typ.red
        else:
            assert isinstance(node, QNode)
            color = typ.red
            opposite_color = typ.green
        output = ""
        if parent_edge is not None:
            output += opposite_color(str(parent_edge)) + "⟶"
        output += color(str(node.__class__.__name__)) + "(n={}, v={:.3f})".format(
            node.num_visits, node.value
        )
        if include_children:
            output += "\n"
            for i, action in enumerate(sorted_by_str(node.children)):
                child = node.children[action]
                child_info = TreeDebugger.single_node_str(child, include_children=False)
                spaces = " " * indent
                output += "{}- [{}] {}: {}".format(
                    spaces, i, typ.white(str(action)), child_info
                )
                if i < len(node.children) - 1:
                    output += "\n"
            output += "\n"
        return output

    @staticmethod
    def preferred_actions(root, max_depth=None):
        """
        Print out the currently preferred actions up to given `max_depth`,
        and return the followed sequence of edges.
        """
        seq = []
        TreeDebugger._preferred_actions_helper(root, 0, seq, max_depth=max_depth)
        return seq

    @staticmethod
    def _preferred_actions_helper(root, depth, seq, max_depth=None):
        # don't care about last layer action because it's outside of planning
        # horizon and only has initial value.
        if max_depth is not None and depth > max_depth:
            return
        if root is None or len(root.children) == 0:
            return
        # Start from the child at index 0 (sorted_by_str order) and switch to
        # any child with a strictly greater value.
        best_child = root.to_edge(0)
        best_value = root[best_child].value
        for c in root.children:
            value = root[c].value
            if value > best_value:
                best_child = c
                best_value = value
        seq.append(best_child)
        equally_good = []
        if isinstance(root, VNode):
            for c in root.children:
                if not (c == best_child) and root[c].value == best_value:
                    equally_good.append(c)
        if best_child is not None and root[best_child] is not None:
            best = root[best_child]
            if isinstance(best, QNode):
                print(" %s %s" % (typ.yellow(str(best_child)), str(equally_good)))
                next_depth = depth
            else:
                next_depth = depth + 1
            TreeDebugger._preferred_actions_helper(
                best, next_depth, seq, max_depth=max_depth
            )

    def path(self, dest):
        """alias for path_to;
        Example usage:
        marking path from root to the first node on the second layer:
        dd.mark(dd.path(dd.layer(2)[0]))
        """
        return self.path_to(dest)

    def path_to(self, dest):
        """Returns a list of keys (actions / observations) that represents the path from
        self.current to the given node `dest`. Returns None if the path does not
        exist. Uses DFS. Can be useful for marking path to a node to a specific
        layer. Note that the returned path is a list of keys (i.e. edges), not nodes.

        ``dest`` may be a TreeDebugger (its current node is used), a node
        wrapper, or a raw tree node.
        """
        # dest may be in the returned list of layer() which could be a TreeDebugger.
        if isinstance(dest, TreeDebugger):
            dest = dest.current
        elif isinstance(dest, TreeNode) and not isinstance(dest, _NodePP):
            # Wrap a raw tree node so it compares equal (same class, same
            # underlying node) to the wrappers met during the search.
            dest = _node_pp(dest)
        worklist = [self.current]
        seen = set({self.current})
        parent = {self.current: None}
        while len(worklist) > 0:
            node = worklist.pop()
            if node == dest:
                return self._get_path(self.current, dest, parent)
            for c in node.children:
                child = node[c]
                if child not in seen:
                    worklist.append(child)
                    seen.add(child)
                    parent[child] = (node, c)
        return None

    def _get_path(self, start, dest, parent):
        """Helper method for path_to: walk ``parent`` links from dest back to start."""
        v = dest
        path = []
        while v != start:
            v, edge = parent[v]
            path.append(edge)
        return list(reversed(path))

    @staticmethod
    def tree_stats(root, max_depth=None):
        """Gather statistics about the tree rooted at ``root``.

        Returns a dict with node counts ("total_vnodes", "total_qnodes"),
        total and maximum numbers of children per node kind, "max_depth"
        (deepest VNode), and the root's "num_visits" and "value".
        """
        stats = {
            "total_vnodes": 0,
            "total_qnodes": 0,
            "total_vnodes_children": 0,
            "total_qnodes_children": 0,
            "max_vnodes_children": 0,
            "max_qnodes_children": 0,
            "max_depth": 0,
        }
        TreeDebugger._tree_stats_helper(root, 0, stats, max_depth=max_depth)
        stats["num_visits"] = root.num_visits
        stats["value"] = root.value
        return stats

    @staticmethod
    def _tree_stats_helper(root, depth, stats, max_depth=None):
        """Recursively accumulate per-node statistics into ``stats``."""
        if max_depth is not None and depth > max_depth:
            return
        if isinstance(root, VNode):
            stats["total_vnodes"] += 1
            stats["total_vnodes_children"] += len(root.children)
            stats["max_vnodes_children"] = max(
                stats["max_vnodes_children"], len(root.children)
            )
            stats["max_depth"] = max(stats["max_depth"], depth)
        else:
            stats["total_qnodes"] += 1
            stats["total_qnodes_children"] += len(root.children)
            stats["max_qnodes_children"] = max(
                stats["max_qnodes_children"], len(root.children)
            )
        for c in root.children:
            child = root[c]
            # A QNode shares its parent VNode's depth; a VNode is one deeper.
            next_depth = depth if isinstance(child, QNode) else depth + 1
            TreeDebugger._tree_stats_helper(
                child, next_depth, stats, max_depth=max_depth
            )
def sorted_by_str(enumerable):
    """Return the items of ``enumerable`` as a list sorted by their string form.

    Gives a stable, type-agnostic ordering for actions/observations that may
    not be mutually comparable; the integer indices used by the debugger
    refer to this order.
    """
    return sorted(enumerable, key=str)
def interpret_color(colorstr):
    """Return the ``typ`` coloring function named by ``colorstr``.

    The name is matched case-insensitively against ``typ.colors``.

    Raises:
        ValueError: if the name is not one of ``typ.colors``.
    """
    name = colorstr.lower()
    if name in typ.colors:
        # Look up the attribute by the normalized name that was validated
        # (instead of eval on the raw input, which failed for e.g. "Blue").
        return getattr(typ, name)
    raise ValueError(
        "Invalid color: {};\nThe available ones are {}".format(colorstr, typ.colors)
    )