"""
`pomdp_py <https://h2r.github.io/pomdp-py/html/>`_ provides functions for calling external solvers
on a POMDP defined using pomdp_py interfaces. Currently, we interface with:

* `pomdp-solve <http://www.pomdp.org/code/index.html>`_ by Anthony R. Cassandra
* `SARSOP <https://github.com/AdaCompNUS/sarsop>`_ by NUS

We hope to interface with:

* `POMDP.jl <https://github.com/JuliaPOMDP/POMDPs.jl>`_
* more? Help us if you can!
"""
import pomdp_py
from pomdp_py.utils.interfaces.conversion import (
to_pomdp_file,
PolicyGraph,
AlphaVectorPolicy,
parse_pomdp_solve_output,
)
import subprocess
import os, sys
[docs]
def vi_pruning(
    agent,
    pomdp_solve_path,
    discount_factor=0.95,
    options=None,
    pomdp_name="temp-pomdp",
    remove_generated_files=False,
    return_policy_graph=False,
):
    """
    Value Iteration with pruning, using the software pomdp-solve
    https://www.pomdp.org/code/ developed by Anthony R. Cassandra.

    The agent's POMDP is written to ``./<pomdp_name>.pomdp``, the solver
    binary is run on that file with ``-o <pomdp_name>``, and the policy is
    then loaded from ``<pomdp_name>.alpha`` (and ``<pomdp_name>.pg`` when a
    policy graph is requested). All of these paths are relative to the
    current working directory.

    Args:
        agent (pomdp_py.Agent): The agent that contains the POMDP definition
        pomdp_solve_path (str): Path to the `pomdp_solve` binary generated after
            compiling the pomdp-solve library.
        discount_factor (float): Discount factor forwarded to ``to_pomdp_file``
            when writing the .pomdp file. Default is 0.95.
        options (list): Additional options to pass in to the command line interface.
            The options should be a list of strings, such as ["-stop_criteria", "weak", ...]
            Every item is converted with ``str`` before being passed on.
            ``None`` (the default) means no additional options.
            Some useful options are:
                -horizon <int>
                -time_limit <int>
        pomdp_name (str): The name used to create the .pomdp file.
        remove_generated_files (bool): True if after policy is computed,
            the .pomdp, .alpha, .pg files are removed. Default is False.
        return_policy_graph (bool): True if return the policy as a PolicyGraph.
            By default is False, in which case an AlphaVectorPolicy is returned.

    Returns:
        PolicyGraph or AlphaVectorPolicy: The policy returned by the solver.

    Raises:
        ValueError: If the agent's states, actions or observations cannot be
            enumerated (accessing them raises ``NotImplementedError``).
    """
    try:
        all_states = list(agent.all_states)
        all_actions = list(agent.all_actions)
        all_observations = list(agent.all_observations)
    except NotImplementedError as err:
        # A bare string cannot be raised (that is itself a TypeError), so
        # wrap the message in a real exception and keep the original cause.
        raise ValueError(
            "S, A, O must be enumerable for a given agent to convert to .pomdp format"
        ) from err

    # Avoid a shared mutable default; None simply means "no extra options".
    extra_args = [] if options is None else [str(opt) for opt in options]

    pomdp_path = "./%s.pomdp" % pomdp_name
    to_pomdp_file(agent, pomdp_path, discount_factor=discount_factor)

    # Run the solver and block until it exits; its output goes to the
    # parent's stdout/stderr.
    proc = subprocess.Popen(
        [pomdp_solve_path, "-pomdp", pomdp_path, "-o", pomdp_name] + extra_args
    )
    proc.wait()

    # Read the value and policy graph files
    alpha_path = "%s.alpha" % pomdp_name
    pg_path = "%s.pg" % pomdp_name
    if return_policy_graph:
        policy = PolicyGraph.construct(
            alpha_path, pg_path, all_states, all_actions, all_observations
        )
    else:
        policy = AlphaVectorPolicy.construct(
            alpha_path, all_states, all_actions, solver="pomdp-solve"
        )

    # Remove temporary files
    if remove_generated_files:
        os.remove(pomdp_path)
        os.remove(alpha_path)
        os.remove(pg_path)
    return policy
[docs]
def sarsop(
    agent,
    pomdpsol_path,
    discount_factor=0.95,
    timeout=30,
    memory=100,
    precision=0.5,
    pomdp_name="temp-pomdp",
    remove_generated_files=False,
    logfile=None,
):
    """
    SARSOP, using the binary from https://github.com/AdaCompNUS/sarsop
    This is an anytime POMDP planning algorithm

    The agent's POMDP is written to ``./<pomdp_name>.pomdp``, the solver
    binary is run on that file with ``--output <pomdp_name>.policy``, and the
    policy is then loaded from ``<pomdp_name>.policy``. These paths are
    relative to the current working directory.

    Args:
        agent (pomdp_py.Agent): The agent that defines the POMDP models
        pomdpsol_path (str): Path to the `pomdpsol` binary
        discount_factor (float): Discount factor forwarded to ``to_pomdp_file``
            when writing the .pomdp file. Default is 0.95.
        timeout (int): The time limit (seconds) to run the algorithm until termination
        memory (int): The memory size (mb) to run the algorithm until termination
        precision (float): solver runs until regret is less than `precision`
        pomdp_name (str): Name of the .pomdp file that will be created when solving
        remove_generated_files (bool): Remove created files during solving after finish.
        logfile (str): Path to file to write the log of both stdout and stderr.
            When given, the solver's stderr is merged into its stdout and each
            output line is both echoed to this process's stdout and written
            to the file. When None, the solver's output is not captured.

    Returns:
        AlphaVectorPolicy: The policy returned by the solver.

    Raises:
        ValueError: If the agent's states, actions or observations cannot be
            enumerated (accessing them raises ``NotImplementedError``).
    """
    try:
        all_states = list(agent.all_states)
        all_actions = list(agent.all_actions)
        # Not used below, but evaluated so that a non-enumerable observation
        # space is rejected up front as well.
        all_observations = list(agent.all_observations)
    except NotImplementedError as err:
        # A bare string cannot be raised (that is itself a TypeError), so
        # wrap the message in a real exception and keep the original cause.
        raise ValueError(
            "S, A, O must be enumerable for a given agent to convert to .pomdpx format"
        ) from err

    pomdp_path = "./%s.pomdp" % pomdp_name
    policy_path = "%s.policy" % pomdp_name
    command = [
        pomdpsol_path,
        "--timeout",
        str(timeout),
        "--memory",
        str(memory),
        "--precision",
        str(precision),
        "--output",
        policy_path,
        pomdp_path,
    ]

    if logfile is None:
        logf = None
        stdout = None
        stderr = None
    else:
        logf = open(logfile, "w")
        # Capture stdout and fold stderr into it so both end up in the log.
        stdout = subprocess.PIPE
        stderr = subprocess.STDOUT

    try:
        to_pomdp_file(agent, pomdp_path, discount_factor=discount_factor)
        proc = subprocess.Popen(command, stdout=stdout, stderr=stderr)
        if logf is not None:
            # Stream the solver output as it is produced: echo to the
            # console and append to the log file.
            for line in proc.stdout:
                line = line.decode("utf-8")
                sys.stdout.write(line)
                logf.write(line)
        proc.wait()

        policy = AlphaVectorPolicy.construct(policy_path, all_states, all_actions)

        # Remove temporary files
        if remove_generated_files:
            os.remove(pomdp_path)
            os.remove(policy_path)
    finally:
        # Close the log even if writing the model, running the solver, or
        # parsing the policy fails.
        if logf is not None:
            logf.close()
    return policy