Source code for pomdp_py.utils.interfaces.solvers

"""
`pomdp_py <https://h2r.github.io/pomdp-py/html/>`_ provides function calls to use external solvers,
given a POMDP defined using pomdp_py interfaces. Currently, we interface with:

* `pomdp-solve <http://www.pomdp.org/code/index.html>`_ by Anthony R. Cassandra
* `SARSOP <https://github.com/AdaCompNUS/sarsop>`_ by NUS

We hope to interface with:

* `POMDP.jl <https://github.com/JuliaPOMDP/POMDPs.jl>`_
* more? Help us if you can!
"""

import pomdp_py
from pomdp_py.utils.interfaces.conversion import (
    to_pomdp_file,
    PolicyGraph,
    AlphaVectorPolicy,
    parse_pomdp_solve_output,
)
import subprocess
import os, sys


def vi_pruning(
    agent,
    pomdp_solve_path,
    discount_factor=0.95,
    options=None,
    pomdp_name="temp-pomdp",
    remove_generated_files=False,
    return_policy_graph=False,
):
    """
    Value Iteration with pruning, using the software pomdp-solve
    https://www.pomdp.org/code/ developed by Anthony R. Cassandra.

    The agent's POMDP is written to ``./<pomdp_name>.pomdp``, the solver
    binary is run on that file, and the policy is then constructed from
    ``<pomdp_name>.alpha`` (and ``<pomdp_name>.pg`` when a policy graph is
    requested). All paths are relative to the current working directory.

    Args:
        agent (pomdp_py.Agent): The agent that contains the POMDP definition
        pomdp_solve_path (str): Path to the `pomdp_solve` binary generated after
            compiling the pomdp-solve library.
        discount_factor (float): Discount factor passed to `to_pomdp_file`
            when writing the .pomdp file. Default is 0.95.
        options (list): Additional options to pass in to the command line interface.
            The options should be a list of strings, such as ["-stop_criteria", "weak", ...]
            Each element is converted with `str` before being passed on.
            Default is None, meaning no additional options.
            Some useful options are:
                -horizon <int>
                -time_limit <int>
        pomdp_name (str): The name used to create the .pomdp file.
        remove_generated_files (bool): True if after policy is computed,
            the .pomdp, .alpha, .pg files are removed. Default is False.
        return_policy_graph (bool): True if return the policy as a PolicyGraph.
            By default is False, in which case an AlphaVectorPolicy is returned.

    Returns:
       PolicyGraph or AlphaVectorPolicy: The policy returned by the solver.

    Raises:
        ValueError: If the agent's states, actions or observations cannot be
            enumerated (accessing them raises NotImplementedError).
    """
    try:
        all_states = list(agent.all_states)
        all_actions = list(agent.all_actions)
        all_observations = list(agent.all_observations)
    except NotImplementedError as err:
        # A bare string cannot be raised; wrap the message in a real exception
        # and keep the original error as its cause.
        raise ValueError(
            "S, A, O must be enumerable for a given agent to convert to .pomdp format"
        ) from err

    pomdp_path = "./%s.pomdp" % pomdp_name
    to_pomdp_file(agent, pomdp_path, discount_factor=discount_factor)

    extra_args = [] if options is None else [str(opt) for opt in options]
    proc = subprocess.Popen(
        [pomdp_solve_path, "-pomdp", pomdp_path, "-o", pomdp_name] + extra_args
    )
    proc.wait()

    # Read the value and policy graph files
    alpha_path = "%s.alpha" % pomdp_name
    pg_path = "%s.pg" % pomdp_name
    if return_policy_graph:
        policy = PolicyGraph.construct(
            alpha_path, pg_path, all_states, all_actions, all_observations
        )
    else:
        policy = AlphaVectorPolicy.construct(
            alpha_path, all_states, all_actions, solver="pomdp-solve"
        )

    # Remove temporary files
    if remove_generated_files:
        os.remove(pomdp_path)
        os.remove(alpha_path)
        os.remove(pg_path)
    return policy
def sarsop(
    agent,
    pomdpsol_path,
    discount_factor=0.95,
    timeout=30,
    memory=100,
    precision=0.5,
    pomdp_name="temp-pomdp",
    remove_generated_files=False,
    logfile=None,
):
    """
    SARSOP, using the binary from https://github.com/AdaCompNUS/sarsop
    This is an anytime POMDP planning algorithm

    The agent's POMDP is written to ``./<pomdp_name>.pomdp``, the `pomdpsol`
    binary is run on it with ``--output <pomdp_name>.policy``, and the
    resulting policy file is loaded as an AlphaVectorPolicy. All paths are
    relative to the current working directory.

    Args:
        agent (pomdp_py.Agent): The agent that defines the POMDP models
        pomdpsol_path (str): Path to the `pomdpsol` binary
        discount_factor (float): Discount factor passed to `to_pomdp_file`
            when writing the .pomdp file. Default is 0.95.
        timeout (int): The time limit (seconds) to run the algorithm until termination
        memory (int): The memory size (mb) to run the algorithm until termination
        precision (float): solver runs until regret is less than `precision`
        pomdp_name (str): Name of the .pomdp file that will be created when solving
        remove_generated_files (bool): Remove created files during solving after finish.
        logfile (str): Path to file to write the log of both stdout and stderr.
            When given, the solver output is also echoed to `sys.stdout`.

    Returns:
       AlphaVectorPolicy: The policy returned by the solver.

    Raises:
        ValueError: If the agent's states, actions or observations cannot be
            enumerated (accessing them raises NotImplementedError).
    """
    try:
        all_states = list(agent.all_states)
        all_actions = list(agent.all_actions)
        # The observation list itself is not used below; it is enumerated only
        # so that a non-enumerable observation space fails here, up front.
        list(agent.all_observations)
    except NotImplementedError as err:
        # A bare string cannot be raised; wrap the message in a real exception
        # and keep the original error as its cause.
        raise ValueError(
            "S, A, O must be enumerable for a given agent to convert to .pomdp format"
        ) from err

    pomdp_path = "./%s.pomdp" % pomdp_name
    policy_path = "%s.policy" % pomdp_name
    command = [
        pomdpsol_path,
        "--timeout",
        str(timeout),
        "--memory",
        str(memory),
        "--precision",
        str(precision),
        "--output",
        policy_path,
        pomdp_path,
    ]

    logf = open(logfile, "w") if logfile is not None else None
    try:
        to_pomdp_file(agent, pomdp_path, discount_factor=discount_factor)

        if logf is None:
            # No log requested: the solver inherits this process's stdout/stderr.
            with subprocess.Popen(command) as proc:
                proc.wait()
        else:
            # Merge stderr into stdout so the log holds both streams; the
            # context manager closes the pipe once the solver is done.
            with subprocess.Popen(
                command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
            ) as proc:
                for raw_line in proc.stdout:
                    # Tolerate stray non-UTF-8 bytes instead of aborting the run.
                    line = raw_line.decode("utf-8", errors="replace")
                    sys.stdout.write(line)
                    logf.write(line)
                proc.wait()

        policy = AlphaVectorPolicy.construct(policy_path, all_states, all_actions)

        # Remove temporary files
        if remove_generated_files:
            os.remove(pomdp_path)
            os.remove(policy_path)
    finally:
        # Close the log even when the solver launch or policy parsing fails.
        if logf is not None:
            logf.close()
    return policy