Skip to content

Commit

Permalink
[issue1127] enforce binary encoding of atoms occuring negated in the …
Browse files Browse the repository at this point in the history
…goal
  • Loading branch information
roeger committed Jan 31, 2024
1 parent 2349940 commit 1230732
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
11 changes: 8 additions & 3 deletions src/translate/fact_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def _update_top(self):
self.groups_by_size[len(candidate)].append(candidate)
self.max_size -= 1

def choose_groups(groups, reachable_facts):
def choose_groups(groups, reachable_facts, negative_in_goal):
if negative_in_goal:
# we remove atoms that occur negatively in the goal from the groups to
# enforce them to be encoded with a binary variable.
groups = [set(group) - negative_in_goal for group in groups]
queue = GroupCoverQueue(groups)
uncovered_facts = reachable_facts.copy()
result = []
Expand Down Expand Up @@ -107,7 +111,8 @@ def sort_groups(groups):
return sorted(sorted(group) for group in groups)

def compute_groups(task: pddl.Task, atoms: Set[pddl.Literal],
reachable_action_params: Dict[pddl.Action, List[str]]) -> Tuple[
reachable_action_params: Dict[pddl.Action, List[str]],
negative_in_goal: Set[pddl.Atom]) -> Tuple[
List[List[pddl.Atom]], # groups
# -> all selected mutex groups plus singleton groups for uncovered facts
List[List[pddl.Atom]], # mutex_groups
Expand All @@ -128,7 +133,7 @@ def compute_groups(task: pddl.Task, atoms: Set[pddl.Literal],
with timers.timing("Collecting mutex groups"):
mutex_groups = collect_all_mutex_groups(groups, atoms)
with timers.timing("Choosing groups", block=True):
groups = choose_groups(groups, atoms)
groups = choose_groups(groups, atoms, negative_in_goal)
groups = sort_groups(groups)
with timers.timing("Building translation key"):
translation_key = build_translation_key(groups)
Expand Down
5 changes: 4 additions & 1 deletion src/translate/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,12 +554,15 @@ def pddl_to_sas(task):
elif goal_list is None:
return unsolvable_sas_task("Trivially false goal")

negative_in_goal = set()
for item in goal_list:
assert isinstance(item, pddl.Literal)
if item.negated:
negative_in_goal.add(item.negate())

with timers.timing("Computing fact groups", block=True):
groups, mutex_groups, translation_key = fact_groups.compute_groups(
task, atoms, reachable_action_params)
task, atoms, reachable_action_params, negative_in_goal)

with timers.timing("Building STRIPS to SAS dictionary"):
ranges, strips_to_sas = strips_to_sas_dictionary(
Expand Down

0 comments on commit 1230732

Please sign in to comment.