Skip to content

Commit

Permalink
support repeat in all ros action
Browse files Browse the repository at this point in the history
  • Loading branch information
fred-labs committed Aug 13, 2024
1 parent eaa942a commit 3306e82
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def setup(self, **kwargs):

def execute(self, node_name: str, state_sequence: list, allow_initial_skip: bool, fail_on_unexpected: bool, keep_running: bool):
if self.node_name != node_name or self.state_sequence != state_sequence:
raise ValueError(f"Updating node name or state_sequence not supported.")
raise ValueError("Runtime change of arguments 'name', 'state_sequence not supported.")

if all(isinstance(state, tuple) and len(state) == 2 for state in self.state_sequence):
self.state_sequence = [state[0] for state in self.state_sequence]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ def setup(self, **kwargs):
tf_static_topic=(tf_prefix + "/tf_static"),
)

def execute(self, frame_id: str, parent_frame_id: str, timeout: int, threshold_translation: float, threshold_rotation: float, wait_for_first_transform: bool, tf_topic_namespace: str, use_sim_time: bool):
if self.tf_topic_namespace != tf_topic_namespace:
raise ValueError("Runtime change of argument 'tf_topic_namespace' not supported.")
self.frame_id = frame_id
self.parent_frame_id = parent_frame_id
self.timeout = timeout
self.threshold_translation = threshold_translation
self.threshold_rotation = threshold_rotation
self.wait_for_first_transform = wait_for_first_transform
self.use_sim_time = use_sim_time

def update(self) -> py_trees.common.Status:
now = time.time()
transform = self.get_transform(self.frame_id, self.parent_frame_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class AssertTopicLatency(BaseAction):
def __init__(self, topic_name: str, topic_type: str, latency: float, comparison_operator: bool, rolling_average_count: int, wait_for_first_message: bool):
super().__init__()
self.topic_name = topic_name
self.topic_type = topic_type
self.topic_type_str = topic_type
self.latency = latency
self.comparison_operator_feedback = comparison_operator[0]
self.comparison_operator = get_comparison_operator(comparison_operator)
Expand Down Expand Up @@ -59,6 +59,9 @@ def setup(self, **kwargs):
elif not success and not self.wait_for_first_message:
raise ValueError("Topic type must be specified. Please provide a valid topic type.")

def execute(self, topic_name: str, topic_type: str, latency: float, comparison_operator: bool, rolling_average_count: int, wait_for_first_message: bool):
if self.timer != 0:
raise ValueError("Action does not yet support to get retriggered")
self.timer = time.time()

def update(self) -> py_trees.common.Status:
Expand Down Expand Up @@ -98,8 +101,8 @@ def check_topic(self):
for name, topic_type in available_topics:
if name == self.topic_name:
topic_type = topic_type[0].replace('/', '.')
if self.topic_type:
if self.topic_type == topic_type:
if self.topic_type_str:
if self.topic_type_str == topic_type:
self.call_subscriber()
self.is_topic = True
return True
Expand All @@ -121,7 +124,7 @@ def check_topic(self):
return True

def call_subscriber(self):
datatype_in_list = self.topic_type.split(".")
datatype_in_list = self.topic_type_str.split(".")
self.topic_type = getattr(
importlib.import_module(".".join(datatype_in_list[:-1])),
datatype_in_list[-1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def __init__(self, associated_actor, distance: float, namespace_override: str):
self.node = None
self.subscriber = None
self.callback_group = None
if namespace_override:
self.namespace = namespace_override
self.namespace_override = namespace_override

def setup(self, **kwargs):
"""
Expand All @@ -52,8 +51,20 @@ def setup(self, **kwargs):
self.name, self.__class__.__name__)
raise KeyError(error_message) from e
self.callback_group = rclpy.callback_groups.MutuallyExclusiveCallbackGroup()
namespace = self.namespace
if self.namespace_override:
namespace = self.namespace_override
self.subscriber = self.node.create_subscription(
Odometry, self.namespace + '/odom', self._callback, 1000, callback_group=self.callback_group)
Odometry, namespace + '/odom', self._callback, 1000, callback_group=self.callback_group)

def execute(self, associated_actor, distance: float, namespace_override: str):
if self.namespace != associated_actor["namespace"] or self.namespace_override != namespace_override:
raise ValueError("Runtime change of namespace not supported.")
self.distance_expected = distance
self.distance_traveled = 0.0
self.previous_x = 0
self.previous_y = 0
self.first_run = True

def _callback(self, msg):
'''
Expand All @@ -80,15 +91,6 @@ def calculate_distance(self, msg):
self.previous_x = msg.pose.pose.position.x
self.previous_y = msg.pose.pose.position.y

def initialise(self):
'''
Initialize before ticking.
'''
self.distance_traveled = 0.0
self.previous_x = 0
self.previous_y = 0
self.first_run = True

def update(self) -> py_trees.common.Status:
"""
Check if the traveled distance is reached
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def setup(self, **kwargs):
)
self.feedback_message = f"Waiting for log" # pylint: disable= attribute-defined-outside-init

def execute(self, values: list, module_name: str):
self.module_name = module_name
self.values = values
self.found = None

def update(self) -> py_trees.common.Status:
"""
Wait for specified log entries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class for setting a node parameter
"""

def __init__(self, node_name: str, parameter_name: str, parameter_value: str):
self.node_name = node_name
self.parameter_name = parameter_name
self.parameter_value = parameter_value
service_name = node_name + '/set_parameters'
if not service_name.startswith('/'):
service_name = '/' + service_name
Expand Down Expand Up @@ -65,6 +68,10 @@ def __init__(self, node_name: str, parameter_name: str, parameter_value: str):
service_type='rcl_interfaces.srv.SetParameters',
data='{ "parameters": [{ "name": "' + parameter_name + '", "value": { "type": ' + str(parameter_type) + ', "' + parameter_assign_name + '": ' + parameter_value + '}}]}')

def execute(self, node_name: str, parameter_name: str, parameter_value: str): # pylint: disable=arguments-differ
if self.node_name != node_name or self.parameter_name != parameter_name or self.parameter_value != parameter_value:
raise ValueError("node_name, parameter_name and parameter_value are not changeable during runtime.")

@staticmethod
def is_float(element: any) -> bool:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,5 @@ import osc.ros
scenario test_ros_service_call:
timeout(30s)
do serial:
service_call() with:
keep(it.service_name == '/bla')
keep(it.service_type == 'std_srvs.srv.SetBool')
keep(it.data == '{\"data\": True}')
service_call('/bla', 'std_srvs.srv.SetBool', '{\"data\": True}')
emit end
19 changes: 19 additions & 0 deletions scenario_execution_ros/test/test_ros_log_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def test_success(self):
self.execute(scenario_content)
self.assertTrue(self.scenario_execution_ros.process_results())

def test_success_repeat(self):
scenario_content = """
import osc.ros
import osc.helpers
scenario test_success:
do parallel:
serial:
serial:
repeat(2)
log_check(values: ['ERROR'])
emit end
time_out: serial:
wait elapsed(10s)
emit fail
"""
self.execute(scenario_content)
self.assertTrue(self.scenario_execution_ros.process_results())

def test_timeout(self):
scenario_content = """
import osc.helpers
Expand Down
27 changes: 26 additions & 1 deletion scenario_execution_ros/test/test_ros_service_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from scenario_execution.model.osc2_parser import OpenScenario2Parser
from scenario_execution.utils.logging import Logger
from ament_index_python.packages import get_package_share_directory

from scenario_execution.model.model_to_py_tree import create_py_tree
from antlr4.InputStream import InputStream
import py_trees
from std_srvs.srv import SetBool

os.environ["PYTHONUNBUFFERED"] = '1'
Expand All @@ -46,6 +48,14 @@ def setUp(self):
self.srv = self.node.create_service(SetBool, "/bla", self.service_callback)
self.parser = OpenScenario2Parser(Logger('test', False))
self.scenario_execution_ros = ROSScenarioExecution()
self.tree = py_trees.composites.Sequence(name="", memory=True)

def execute(self, scenario_content):
parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content))
model = self.parser.create_internal_model(parsed_tree, self.tree, "test.osc", False)
self.tree = create_py_tree(model, self.tree, self.parser.logger, False)
self.scenario_execution_ros.tree = self.tree
self.scenario_execution_ros.run()

def tearDown(self):
self.node.destroy_node()
Expand All @@ -62,3 +72,18 @@ def test_success(self):
self.scenario_execution_ros.run()
self.assertTrue(self.scenario_execution_ros.process_results())
self.assertTrue(self.request_received)

def test_success_repeat(self):
scenario_content = """
import osc.helpers
import osc.ros
scenario test_ros_service_call:
timeout(30s)
do serial:
repeat(2)
service_call('/bla', 'std_srvs.srv.SetBool', '{\\\"data\\\": True}')
emit end
"""
self.execute(scenario_content)
self.assertTrue(self.scenario_execution_ros.process_results())

0 comments on commit 3306e82

Please sign in to comment.