From dd670961cb1fb8b62c28adeb757f2a3c7dc6a757 Mon Sep 17 00:00:00 2001 From: fred-labs Date: Mon, 9 Sep 2024 16:12:46 +0200 Subject: [PATCH] action_call: fix shutdown (#182) --- .../scenario_execution/__init__.py | 5 +++-- .../scenario_execution_base.py | 18 ++++++++++++++++++ .../actions/ros_action_call.py | 10 +++++++--- .../scenario_execution_ros.py | 8 +++++--- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/scenario_execution/scenario_execution/__init__.py b/scenario_execution/scenario_execution/__init__.py index f4cfe100..52c90eea 100644 --- a/scenario_execution/scenario_execution/__init__.py +++ b/scenario_execution/scenario_execution/__init__.py @@ -17,7 +17,7 @@ from . import actions from . import utils from . import model -from scenario_execution.scenario_execution_base import ScenarioExecution +from scenario_execution.scenario_execution_base import ScenarioExecution, ShutdownHandler from scenario_execution.utils.logging import BaseLogger, Logger __all__ = [ @@ -26,5 +26,6 @@ 'model', 'BaseLogger', "Logger", - 'ScenarioExecution' + 'ScenarioExecution', + 'ShutdownHandler' ] diff --git a/scenario_execution/scenario_execution/scenario_execution_base.py b/scenario_execution/scenario_execution/scenario_execution_base.py index 6c85412e..b0ff27ca 100644 --- a/scenario_execution/scenario_execution/scenario_execution_base.py +++ b/scenario_execution/scenario_execution/scenario_execution_base.py @@ -29,6 +29,24 @@ from timeit import default_timer as timer +class ShutdownHandler: + _instance = None + + def __init__(self): + self.futures = [] + + def get_instance(): # pylint: disable=no-method-argument + if ShutdownHandler._instance is None: + ShutdownHandler._instance = ShutdownHandler() + return ShutdownHandler._instance + + def add_future(self, future): + self.futures.append(future) + + def is_done(self): + return all(fut.done() for fut in self.futures) + + @dataclass class ScenarioResult: name: str diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_action_call.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_action_call.py index 43161f6d..00d81326 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_action_call.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_action_call.py @@ -25,6 +25,7 @@ import py_trees # pylint: disable=import-error from action_msgs.msg import GoalStatus from scenario_execution.actions.base_action import BaseAction, ActionError +from scenario_execution import ShutdownHandler class ActionCallActionState(Enum): @@ -152,8 +153,9 @@ def goal_response_callback(self, future): return self.current_state = ActionCallActionState.ACTION_ACCEPTED self.feedback_message = f"Goal accepted." # pylint: disable= attribute-defined-outside-init - get_result_future = self.goal_handle.get_result_async() - get_result_future.add_done_callback(self.get_result_callback) + if not self.success_on_acceptance: + get_result_future = self.goal_handle.get_result_async() + get_result_future.add_done_callback(self.get_result_callback) def get_result_callback(self, future): """ @@ -179,7 +181,9 @@ def get_result_callback(self, future): def shutdown(self): if self.goal_handle: - self.goal_handle.cancel_goal() + future = self.goal_handle.cancel_goal_async() + shutdown_handler = ShutdownHandler.get_instance() + shutdown_handler.add_future(future) def get_feedback_message(self, current_state): feedback_message = None diff --git a/scenario_execution_ros/scenario_execution_ros/scenario_execution_ros.py b/scenario_execution_ros/scenario_execution_ros/scenario_execution_ros.py index 00a93d62..0861bd9f 100644 --- a/scenario_execution_ros/scenario_execution_ros/scenario_execution_ros.py +++ b/scenario_execution_ros/scenario_execution_ros/scenario_execution_ros.py @@ -19,7 +19,7 @@ import rclpy # pylint: disable=import-error import py_trees_ros # pylint: disable=import-error from py_trees_ros_interfaces.srv import OpenSnapshotStream -from scenario_execution import ScenarioExecution +from scenario_execution import ScenarioExecution, ShutdownHandler from .logging_ros import RosLogger from .marker_handler import MarkerHandler @@ -119,7 +119,10 @@ def run(self) -> bool: self.on_scenario_shutdown(False, "Aborted") if self.shutdown_task is not None and self.shutdown_task.done(): - break + shutdown_handler = ShutdownHandler.get_instance() + if shutdown_handler.is_done(): + self.logger.info("Shutting down finished.") + break except Exception as e: # pylint: disable=broad-except self.on_scenario_shutdown(False, "Run failed", f"{e}") finally: @@ -128,7 +131,6 @@ def run(self) -> bool: def shutdown(self): self.logger.info("Shutting down...") self.behaviour_tree.shutdown() - self.logger.info("Shutting down finished.") def on_scenario_shutdown(self, result, failure_message="", failure_output=""): if self.shutdown_requested: