Skip to content

Commit

Permalink
Implemented AStar and rewrote BFS to use AStar algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
Brandon-Rozek committed Nov 2, 2023
1 parent 2f08f98 commit edde3cc
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 121 deletions.
161 changes: 161 additions & 0 deletions src/main/java/org/rairlab/planner/AStarPlanner.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package org.rairlab.planner;

import org.rairlab.shadow.prover.representations.formula.Formula;

import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.apache.commons.lang3.tuple.Pair;

class AStarComparator implements Comparator<Pair<State, List<Action>>> {
private Map<Pair<State, List<Action>>, Integer> heuristic;

public AStarComparator() {
this.heuristic = new HashMap<Pair<State, List<Action>>, Integer>();
}

@Override
public int compare(Pair<State, List<Action>> o1, Pair<State, List<Action>> o2) {
// Print nag message if undefined behavior is happening
if (!this.heuristic.containsKey(o1) || !this.heuristic.containsKey(o2)) {
System.out.println("[ERROR] Heuristic is not defined for state");
}

int i1 = this.heuristic.get(o1);
int i2 = this.heuristic.get(o2);
return i1 < i2 ? -1: 1;
}

public void setValue(Pair<State, List<Action>> k, int v) {
this.heuristic.put(k, v);
}

public int getValue(Pair<State, List<Action>> k) {
return this.heuristic.get(k);
}
}

/**
* Created by brandonrozek on 03/29/2023.
*/
public class AStarPlanner {

// The longest plan to search for, -1 means no bound
private Optional<Integer> MAX_DEPTH = Optional.empty();
// Number of plans to look for, -1 means up to max_depth
private Optional<Integer> K = Optional.empty();

public AStarPlanner(){ }

public Set<Plan> plan(Set<Formula> background, Set<Action> actions, State start, State goal, Function<State, Integer> heuristic) {

// Search Space Data Structures
Set<State> history = new HashSet<State>();
// Each node in the search space consists of
// (state, sequence of actions from initial)
AStarComparator comparator = new AStarComparator();
Queue<Pair<State, List<Action>>> search = new PriorityQueue<Pair<State,List<Action>>>(comparator);

// Submit Initial State
Pair<State, List<Action>> searchStart = Pair.of(start, new ArrayList<Action>());
comparator.setValue(searchStart, 0);
search.add(searchStart);

// Current set of plans
Set<Plan> plansFound = new HashSet<Plan>();

// AStar Traversal until
// - No more actions can be applied
// - Max depth reached
// - Found K plans
while (!search.isEmpty()) {


Pair<State, List<Action>> currentSearch = search.remove();
State lastState = currentSearch.getLeft();
List<Action> previous_actions = currentSearch.getRight();

// System.out.println("Considering state with heuristic: " + comparator.getValue(currentSearch));

// Exit loop if we've passed the depth limit
int currentDepth = previous_actions.size();
if (MAX_DEPTH.isPresent() && currentDepth > MAX_DEPTH.get()) {
break;
}

// If we're at the goal return
if (Operations.satisfies(background, lastState, goal)) {
plansFound.add(new Plan(previous_actions));
if (K.isPresent() && plansFound.size() >= K.get()) {
break;
}
continue;
}

// Only consider non-trivial actions
Set<Action> nonTrivialActions = actions.stream()
.filter(Action::isNonTrivial)
.collect(Collectors.toSet());

// Apply the action to the state and add to the search space
for (Action action : nonTrivialActions) {
Optional<Set<Pair<State, Action>>> optNextStateActionPairs = Operations.apply(background, action, lastState);

// Ignore actions that aren't applicable
if (optNextStateActionPairs.isEmpty()) {
continue;
}

// Action's aren't grounded so each nextState represents
// a different parameter binding
Set<Pair<State, Action>> nextStateActionPairs = optNextStateActionPairs.get();
for (Pair<State, Action> stateActionPair: nextStateActionPairs) {
State nextState = stateActionPair.getLeft();
Action nextAction = stateActionPair.getRight();

// Prune already visited states
if (history.contains(nextState)) {
continue;
}

// Add to history
history.add(nextState);

// Construct search space parameters
List<Action> next_actions = new ArrayList<Action>(previous_actions);
next_actions.add(nextAction);

// Add to search space
Pair<State, List<Action>> futureSearch = Pair.of(nextState, next_actions);
int planCost = next_actions.stream().map(Action::getCost).reduce(0, (a, b) -> a + b);
int heuristicValue = heuristic.apply(nextState);
comparator.setValue(futureSearch, planCost + heuristicValue);
search.add(futureSearch);
}
}
}

return plansFound;
}

public Optional<Integer> getMaxDepth() {
return MAX_DEPTH;
}

public void setMaxDepth(int maxDepth) {
MAX_DEPTH = Optional.of(maxDepth);
}

public void setK(int k) {
K = Optional.of(k);
}

public void clearK() {
K = Optional.empty();
}

public Optional<Integer> getK() {
return K;
}
}
33 changes: 19 additions & 14 deletions src/main/java/org/rairlab/planner/Action.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.ArrayList;

/**
* Created by naveensundarg on 1/13/17.
Expand All @@ -28,12 +29,13 @@ public class Action {
private final String name;
private final Formula precondition;

private int cost;
private int weight;
private final boolean trivial;

private final Compound shorthand;

public Action(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, List<Variable> freeVariables, List<Variable> interestedVars) {
public Action(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, int cost, List<Variable> freeVariables, List<Variable> interestedVars) {
this.name = name;
this.preconditions = preconditions;

Expand All @@ -52,6 +54,7 @@ public Action(String name, Set<Formula> preconditions, Set<Formula> additions, S
this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() +
additions.stream().mapToInt(Formula::getWeight).sum() +
deletions.stream().mapToInt(Formula::getWeight).sum();
this.cost = cost;

List<Value> valuesList = interestedVars.stream().collect(Collectors.toList());;
this.shorthand = new Compound(name, valuesList);
Expand All @@ -61,7 +64,7 @@ public Action(String name, Set<Formula> preconditions, Set<Formula> additions, S
}

public Action(String name, Set<Formula> preconditions, Set<Formula> additions,
Set<Formula> deletions, List<Variable> freeVariables,
Set<Formula> deletions, int cost, List<Variable> freeVariables,
Compound shorthand
) {
this.name = name;
Expand All @@ -82,6 +85,7 @@ public Action(String name, Set<Formula> preconditions, Set<Formula> additions,
this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() +
additions.stream().mapToInt(Formula::getWeight).sum() +
deletions.stream().mapToInt(Formula::getWeight).sum();
this.cost = cost;

this.shorthand = shorthand;
this.trivial = computeTrivialOrNot();
Expand All @@ -94,26 +98,32 @@ public static Action buildActionFrom(String name,
Set<Formula> preconditions,
Set<Formula> additions,
Set<Formula> deletions,
int cost,
List<Variable> freeVariables) {

return new Action(name, preconditions, additions, deletions, freeVariables, freeVariables);
return new Action(name, preconditions, additions, deletions, cost, freeVariables, freeVariables);

}

public static Action buildActionFrom(String name,
Set<Formula> preconditions,
Set<Formula> additions,
Set<Formula> deletions,
int cost,
List<Variable> freeVariables, List<Variable> interestedVars) {

return new Action(name, preconditions, additions, deletions, freeVariables, interestedVars);
return new Action(name, preconditions, additions, deletions, cost, freeVariables, interestedVars);

}

public int getWeight() {
return weight;
}

public int getCost() {
return cost;
}

public Formula getPrecondition() {
return precondition;
}
Expand All @@ -131,16 +141,11 @@ public Set<Formula> getDeletions() {
}

public List<Variable> openVars() {
return freeVariables;
}

Set<Variable> variables = Sets.newSet();

variables.addAll(freeVariables);

List<Variable> variablesList = CollectionUtils.newEmptyList();

variablesList.addAll(variables);
return variablesList;

public List<Variable> getInterestedVars() {
return interestedVars;
}

public Set<Formula> instantiateAdditions(Map<Variable, Value> mapping) {
Expand Down Expand Up @@ -172,7 +177,7 @@ public Action instantiate(Map<Variable, Value> binding){

List<Value> valuesList = interestedVars.stream().collect(Collectors.toList());;
Compound shorthand = (Compound)(new Compound(name, valuesList)).apply(binding);
return new Action(name, newPreconditions, newAdditions, newDeletions, newFreeVariables, shorthand);
return new Action(name, newPreconditions, newAdditions, newDeletions, cost, newFreeVariables, shorthand);
}

public String getName() {
Expand Down
Loading

0 comments on commit edde3cc

Please sign in to comment.