From edde3cc8e56dbb4d4988a60adfd1cfa8be1ede1b Mon Sep 17 00:00:00 2001 From: Brandon Rozek Date: Thu, 2 Nov 2023 16:26:20 -0400 Subject: [PATCH] Implemented AStar and rewrote BFS to use AStar algorithm --- .../org/rairlab/planner/AStarPlanner.java | 161 ++++++++++++++++++ src/main/java/org/rairlab/planner/Action.java | 33 ++-- .../rairlab/planner/BreadthFirstPlanner.java | 113 +++--------- .../java/org/rairlab/planner/Operations.java | 18 +- src/main/java/org/rairlab/planner/Plan.java | 6 + .../planner/heuristics/ConstantHeuristic.java | 9 + .../planner/utils/IndefiniteAction.java | 2 +- .../planner/utils/PlanningProblem.java | 9 +- .../org/rairlab/planner/utils/Runner.java | 18 +- 9 files changed, 248 insertions(+), 121 deletions(-) create mode 100644 src/main/java/org/rairlab/planner/AStarPlanner.java create mode 100644 src/main/java/org/rairlab/planner/heuristics/ConstantHeuristic.java diff --git a/src/main/java/org/rairlab/planner/AStarPlanner.java b/src/main/java/org/rairlab/planner/AStarPlanner.java new file mode 100644 index 0000000..38b299b --- /dev/null +++ b/src/main/java/org/rairlab/planner/AStarPlanner.java @@ -0,0 +1,161 @@ +package org.rairlab.planner; + +import org.rairlab.shadow.prover.representations.formula.Formula; + +import java.util.*; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.tuple.Pair; + +class AStarComparator implements Comparator>> { + private Map>, Integer> heuristic; + + public AStarComparator() { + this.heuristic = new HashMap>, Integer>(); + } + + @Override + public int compare(Pair> o1, Pair> o2) { + // Print nag message if undefined behavior is happening + if (!this.heuristic.containsKey(o1) || !this.heuristic.containsKey(o2)) { + System.out.println("[ERROR] Heuristic is not defined for state"); + } + + int i1 = this.heuristic.get(o1); + int i2 = this.heuristic.get(o2); + return i1 < i2 ? -1: 1; + } + + public void setValue(Pair> k, int v) { + this.heuristic.put(k, v); + } + + public int getValue(Pair> k) { + return this.heuristic.get(k); + } +} + +/** + * Created by brandonrozek on 03/29/2023. + */ +public class AStarPlanner { + + // The longest plan to search for, -1 means no bound + private Optional MAX_DEPTH = Optional.empty(); + // Number of plans to look for, -1 means up to max_depth + private Optional K = Optional.empty(); + + public AStarPlanner(){ } + + public Set plan(Set background, Set actions, State start, State goal, Function heuristic) { + + // Search Space Data Structures + Set history = new HashSet(); + // Each node in the search space consists of + // (state, sequence of actions from initial) + AStarComparator comparator = new AStarComparator(); + Queue>> search = new PriorityQueue>>(comparator); + + // Submit Initial State + Pair> searchStart = Pair.of(start, new ArrayList()); + comparator.setValue(searchStart, 0); + search.add(searchStart); + + // Current set of plans + Set plansFound = new HashSet(); + + // AStar Traversal until + // - No more actions can be applied + // - Max depth reached + // - Found K plans + while (!search.isEmpty()) { + + + Pair> currentSearch = search.remove(); + State lastState = currentSearch.getLeft(); + List previous_actions = currentSearch.getRight(); + + // System.out.println("Considering state with heuristic: " + comparator.getValue(currentSearch)); + + // Exit loop if we've passed the depth limit + int currentDepth = previous_actions.size(); + if (MAX_DEPTH.isPresent() && currentDepth > MAX_DEPTH.get()) { + break; + } + + // If we're at the goal return + if (Operations.satisfies(background, lastState, goal)) { + plansFound.add(new Plan(previous_actions)); + if (K.isPresent() && plansFound.size() >= K.get()) { + break; + } + continue; + } + + // Only consider non-trivial actions + Set nonTrivialActions = actions.stream() + .filter(Action::isNonTrivial) + .collect(Collectors.toSet()); + + // Apply the action to the state and add to the search space + for (Action action : nonTrivialActions) { + Optional>> optNextStateActionPairs = Operations.apply(background, action, lastState); + + // Ignore actions that aren't applicable + if (optNextStateActionPairs.isEmpty()) { + continue; + } + + // Action's aren't grounded so each nextState represents + // a different parameter binding + Set> nextStateActionPairs = optNextStateActionPairs.get(); + for (Pair stateActionPair: nextStateActionPairs) { + State nextState = stateActionPair.getLeft(); + Action nextAction = stateActionPair.getRight(); + + // Prune already visited states + if (history.contains(nextState)) { + continue; + } + + // Add to history + history.add(nextState); + + // Construct search space parameters + List next_actions = new ArrayList(previous_actions); + next_actions.add(nextAction); + + // Add to search space + Pair> futureSearch = Pair.of(nextState, next_actions); + int planCost = next_actions.stream().map(Action::getCost).reduce(0, (a, b) -> a + b); + int heuristicValue = heuristic.apply(nextState); + comparator.setValue(futureSearch, planCost + heuristicValue); + search.add(futureSearch); + } + } + } + + return plansFound; + } + + public Optional getMaxDepth() { + return MAX_DEPTH; + } + + public void setMaxDepth(int maxDepth) { + MAX_DEPTH = Optional.of(maxDepth); + } + + public void setK(int k) { + K = Optional.of(k); + } + + public void clearK() { + K = Optional.empty(); + } + + public Optional getK() { + return K; + } +} diff --git a/src/main/java/org/rairlab/planner/Action.java b/src/main/java/org/rairlab/planner/Action.java index e6f315b..7d4d0f0 100644 --- a/src/main/java/org/rairlab/planner/Action.java +++ b/src/main/java/org/rairlab/planner/Action.java @@ -13,6 +13,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import java.util.ArrayList; /** * Created by naveensundarg on 1/13/17. @@ -28,12 +29,13 @@ public class Action { private final String name; private final Formula precondition; + private int cost; private int weight; private final boolean trivial; private final Compound shorthand; - public Action(String name, Set preconditions, Set additions, Set deletions, List freeVariables, List interestedVars) { + public Action(String name, Set preconditions, Set additions, Set deletions, int cost, List freeVariables, List interestedVars) { this.name = name; this.preconditions = preconditions; @@ -52,6 +54,7 @@ public Action(String name, Set preconditions, Set additions, S this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() + additions.stream().mapToInt(Formula::getWeight).sum() + deletions.stream().mapToInt(Formula::getWeight).sum(); + this.cost = cost; List valuesList = interestedVars.stream().collect(Collectors.toList());; this.shorthand = new Compound(name, valuesList); @@ -61,7 +64,7 @@ public Action(String name, Set preconditions, Set additions, S } public Action(String name, Set preconditions, Set additions, - Set deletions, List freeVariables, + Set deletions, int cost, List freeVariables, Compound shorthand ) { this.name = name; @@ -82,6 +85,7 @@ public Action(String name, Set preconditions, Set additions, this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() + additions.stream().mapToInt(Formula::getWeight).sum() + deletions.stream().mapToInt(Formula::getWeight).sum(); + this.cost = cost; this.shorthand = shorthand; this.trivial = computeTrivialOrNot(); @@ -94,9 +98,10 @@ public static Action buildActionFrom(String name, Set preconditions, Set additions, Set deletions, + int cost, List freeVariables) { - return new Action(name, preconditions, additions, deletions, freeVariables, freeVariables); + return new Action(name, preconditions, additions, deletions, cost, freeVariables, freeVariables); } @@ -104,9 +109,10 @@ public static Action buildActionFrom(String name, Set preconditions, Set additions, Set deletions, + int cost, List freeVariables, List interestedVars) { - return new Action(name, preconditions, additions, deletions, freeVariables, interestedVars); + return new Action(name, preconditions, additions, deletions, cost, freeVariables, interestedVars); } @@ -114,6 +120,10 @@ public int getWeight() { return weight; } + public int getCost() { + return cost; + } + public Formula getPrecondition() { return precondition; } @@ -131,16 +141,11 @@ public Set getDeletions() { } public List openVars() { + return freeVariables; + } - Set variables = Sets.newSet(); - - variables.addAll(freeVariables); - - List variablesList = CollectionUtils.newEmptyList(); - - variablesList.addAll(variables); - return variablesList; - + public List getInterestedVars() { + return interestedVars; } public Set instantiateAdditions(Map mapping) { @@ -172,7 +177,7 @@ public Action instantiate(Map binding){ List valuesList = interestedVars.stream().collect(Collectors.toList());; Compound shorthand = (Compound)(new Compound(name, valuesList)).apply(binding); - return new Action(name, newPreconditions, newAdditions, newDeletions, newFreeVariables, shorthand); + return new Action(name, newPreconditions, newAdditions, newDeletions, cost, newFreeVariables, shorthand); } public String getName() { diff --git a/src/main/java/org/rairlab/planner/BreadthFirstPlanner.java b/src/main/java/org/rairlab/planner/BreadthFirstPlanner.java index f366fda..b2c32d6 100644 --- a/src/main/java/org/rairlab/planner/BreadthFirstPlanner.java +++ b/src/main/java/org/rairlab/planner/BreadthFirstPlanner.java @@ -1,127 +1,58 @@ package org.rairlab.planner; import org.rairlab.shadow.prover.representations.formula.Formula; +import org.rairlab.planner.Action; import java.util.*; -import java.util.stream.Collectors; -import org.apache.commons.lang3.tuple.Pair; /** * Created by brandonrozek on 03/29/2023. */ public class BreadthFirstPlanner { - // The longest plan to search for, -1 means no bound - private Optional MAX_DEPTH = Optional.empty(); - // Number of plans to look for, -1 means up to max_depth - private Optional K = Optional.empty(); + private AStarPlanner planner; - public BreadthFirstPlanner(){ } + public BreadthFirstPlanner(){ + planner = new AStarPlanner(); + } + + public static int h(State s) { + return 1; + } public Set plan(Set background, Set actions, State start, State goal) { - // Search Space Data Structures - Set history = new HashSet(); - // Each node in the search space consists of - // (state, sequence of actions from initial) - Queue, List>> search = new ArrayDeque,List>>(); - - // Submit Initial State - search.add(Pair.of(List.of(start), new ArrayList())); - - // Current set of plans - Set plansFound = new HashSet(); - - // Breadth First Traversal until - // - No more actions can be applied - // - Max depth reached - // - Found K plans - while (!search.isEmpty()) { - - Pair, List> currentSearch = search.remove(); - List previous_states = currentSearch.getLeft(); - List previous_actions = currentSearch.getRight(); - State lastState = previous_states.get(previous_states.size() - 1); - - // Exit loop if we've passed the depth limit - int currentDepth = previous_actions.size(); - if (MAX_DEPTH.isPresent() && currentDepth > MAX_DEPTH.get()) { - break; - } - - // If we're at the goal return - if (Operations.satisfies(background, lastState, goal)) { - plansFound.add(new Plan(previous_actions, previous_states, background)); - if (K.isPresent() && plansFound.size() >= K.get()) { - break; - } - continue; - } - - // Only consider non-trivial actions - Set nonTrivialActions = actions.stream() - .filter(Action::isNonTrivial) - .collect(Collectors.toSet()); - - // Apply the action to the state and add to the search space - for (Action action : nonTrivialActions) { - Optional>> optNextStateActionPairs = Operations.apply(background, action, lastState); - - // Ignore actions that aren't applicable - if (optNextStateActionPairs.isEmpty()) { - continue; - } - - // Action's aren't grounded so each nextState represents - // a different parameter binding - Set> nextStateActionPairs = optNextStateActionPairs.get(); - for (Pair stateActionPair: nextStateActionPairs) { - State nextState = stateActionPair.getLeft(); - Action nextAction = stateActionPair.getRight(); - - // Prune already visited states - if (history.contains(nextState)) { - continue; - } - - // Add to history - history.add(nextState); - - // Construct search space parameters - List next_states = new ArrayList(previous_states); - next_states.add(nextState); - - List next_actions = new ArrayList(previous_actions); - next_actions.add(nextAction); - - // Add to search space - search.add(Pair.of(next_states, next_actions)); - } - } + // For BFS, need to ignore action costs + Set newActions = new HashSet(); + for (Action a : actions) { + newActions.add(new Action( + a.getName(), a.getPreconditions(), a.getAdditions(), a.getDeletions(), + 1, a.openVars(), a.getInterestedVars() + )); } - return plansFound; + return planner.plan(background, actions, start, goal, BreadthFirstPlanner::h); } public Optional getMaxDepth() { - return MAX_DEPTH; + return planner.getMaxDepth(); } public void setMaxDepth(int maxDepth) { - MAX_DEPTH = Optional.of(maxDepth); + planner.setMaxDepth(maxDepth); } public void setK(int k) { - K = Optional.of(k); + planner.setK(k); } public void clearK() { - K = Optional.empty(); + planner.clearK(); } public Optional getK() { - return K; + return planner.getK(); } } \ No newline at end of file diff --git a/src/main/java/org/rairlab/planner/Operations.java b/src/main/java/org/rairlab/planner/Operations.java index df0510b..f31e145 100644 --- a/src/main/java/org/rairlab/planner/Operations.java +++ b/src/main/java/org/rairlab/planner/Operations.java @@ -4,10 +4,8 @@ import org.rairlab.shadow.prover.core.Prover; import org.rairlab.shadow.prover.core.SnarkWrapper; import org.rairlab.shadow.prover.core.proof.Justification; -import org.rairlab.planner.utils.Commons; import org.rairlab.shadow.prover.representations.formula.BiConditional; import org.rairlab.shadow.prover.representations.formula.Formula; -import org.rairlab.shadow.prover.representations.formula.Predicate; import org.rairlab.shadow.prover.representations.value.Value; import org.rairlab.shadow.prover.representations.value.Variable; @@ -23,7 +21,6 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.*; -import java.util.stream.Collectors; /** * Created by naveensundarg on 1/13/17. @@ -225,10 +222,17 @@ public static boolean satisfies(Set background, State state, State goal return true; } - return proveCached( - Sets.union(background, state.getFormulae()), - Commons.makeAnd(goal.getFormulae()) - ).isPresent(); + for (Formula g : goal.getFormulae()) { + Optional just = proveCached( + Sets.union(background, state.getFormulae()), + g + ); + if (just.isEmpty()) { + return false; + } + } + + return true; } public static boolean conflicts(Set background, State state1, State state2) { diff --git a/src/main/java/org/rairlab/planner/Plan.java b/src/main/java/org/rairlab/planner/Plan.java index c531adb..2a9cfb0 100644 --- a/src/main/java/org/rairlab/planner/Plan.java +++ b/src/main/java/org/rairlab/planner/Plan.java @@ -28,6 +28,12 @@ public Plan(List actions, List expectedStates, Set backg this.background = background; } + public Plan(List actions) { + this.actions = actions; + this.expectedStates = CollectionUtils.newEmptyList(); + this.background = CollectionUtils.newEmptySet(); + } + public List getActions() { return actions; } diff --git a/src/main/java/org/rairlab/planner/heuristics/ConstantHeuristic.java b/src/main/java/org/rairlab/planner/heuristics/ConstantHeuristic.java new file mode 100644 index 0000000..f4bb132 --- /dev/null +++ b/src/main/java/org/rairlab/planner/heuristics/ConstantHeuristic.java @@ -0,0 +1,9 @@ +package org.rairlab.planner.heuristics; + +import org.rairlab.planner.State; + +public class ConstantHeuristic { + public static int h(State s) { + return 1; + } +} diff --git a/src/main/java/org/rairlab/planner/utils/IndefiniteAction.java b/src/main/java/org/rairlab/planner/utils/IndefiniteAction.java index 1a4783f..2f07f34 100644 --- a/src/main/java/org/rairlab/planner/utils/IndefiniteAction.java +++ b/src/main/java/org/rairlab/planner/utils/IndefiniteAction.java @@ -10,6 +10,6 @@ public class IndefiniteAction extends Action { private IndefiniteAction(String name, Set preconditions, Set additions, Set deletions, List freeVariables) { - super(name, preconditions, additions, deletions, freeVariables, freeVariables); + super(name, preconditions, additions, deletions, 1, freeVariables, freeVariables); } } diff --git a/src/main/java/org/rairlab/planner/utils/PlanningProblem.java b/src/main/java/org/rairlab/planner/utils/PlanningProblem.java index 5817b93..ca5d6a7 100644 --- a/src/main/java/org/rairlab/planner/utils/PlanningProblem.java +++ b/src/main/java/org/rairlab/planner/utils/PlanningProblem.java @@ -58,6 +58,7 @@ public class PlanningProblem { private static final Keyword PRECONDITIONS = Keyword.newKeyword("preconditions"); private static final Keyword ADDITIONS = Keyword.newKeyword("additions"); private static final Keyword DELETIONS = Keyword.newKeyword("deletions"); + private static final Keyword COST = Keyword.newKeyword("cost"); private static final Symbol ACTION_DEFINER = Symbol.newSymbol("define-action"); @@ -311,11 +312,17 @@ private static Set readActionsFrom(List actionSpecs) throws Reader.Pa Set preconditions = readFrom((List) actionSpec.get(PRECONDITIONS)); Set additions = readFrom((List) actionSpec.get(ADDITIONS)); Set deletions = readFrom((List) actionSpec.get(DELETIONS)); + int cost; + if (actionSpec.containsKey(COST)) { + cost = Integer.parseInt(actionSpec.get(COST).toString()); + } else { + cost = 1; + } List interestedVars = CollectionUtils.newEmptyList(); interestedVars.addAll(vars); vars.addAll(preconditions.stream().map(Formula::variablesPresent).reduce(Sets.newSet(), Sets::union)); - return Action.buildActionFrom(name, preconditions, additions, deletions, vars, interestedVars); + return Action.buildActionFrom(name, preconditions, additions, deletions, cost, vars, interestedVars); } catch (Reader.ParsingException e) { diff --git a/src/main/java/org/rairlab/planner/utils/Runner.java b/src/main/java/org/rairlab/planner/utils/Runner.java index 8cd77e7..9f0d65d 100644 --- a/src/main/java/org/rairlab/planner/utils/Runner.java +++ b/src/main/java/org/rairlab/planner/utils/Runner.java @@ -1,8 +1,8 @@ package org.rairlab.planner.utils; -import org.rairlab.planner.BreadthFirstPlanner; +import org.rairlab.planner.AStarPlanner; import org.rairlab.planner.Plan; -import org.rairlab.planner.Planner; +import org.rairlab.planner.heuristics.ConstantHeuristic; import org.rairlab.shadow.prover.utils.Reader; import java.io.FileInputStream; @@ -10,6 +10,7 @@ import java.util.*; + public final class Runner { public static void main(String[] args) { @@ -43,16 +44,19 @@ public static void main(String[] args) { e.printStackTrace(); return; } - - BreadthFirstPlanner breadthFirstPlanner = new BreadthFirstPlanner(); - breadthFirstPlanner.setK(2); + + AStarPlanner astarplanner = new AStarPlanner(); + astarplanner.setK(2); for (PlanningProblem planningProblem : planningProblemList) { - Set plans = breadthFirstPlanner.plan( + + Set plans = astarplanner.plan( planningProblem.getBackground(), planningProblem.getActions(), planningProblem.getStart(), - planningProblem.getGoal()); + planningProblem.getGoal(), + ConstantHeuristic::h + ); if(plans.size() > 0) { System.out.println(plans.toString());