Skip to content

Commit

Permalink
Fix service_call with repeat() modifier (#154)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Nikhil <[email protected]>
  • Loading branch information
fred-labs and Nikhil-Singhal-06 committed Aug 14, 2024
1 parent b8d7b4c commit 65835e9
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def setup(self, **kwargs):

def execute(self, node_name: str, state_sequence: list, allow_initial_skip: bool, fail_on_unexpected: bool, keep_running: bool):
if self.node_name != node_name or self.state_sequence != state_sequence:
raise ValueError(f"Updating node name or state_sequence not supported.")
raise ValueError("Runtime change of arguments 'name', 'state_sequence not supported.")

if all(isinstance(state, tuple) and len(state) == 2 for state in self.state_sequence):
self.state_sequence = [state[0] for state in self.state_sequence]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ def setup(self, **kwargs):
tf_static_topic=(tf_prefix + "/tf_static"),
)

def execute(self, frame_id: str, parent_frame_id: str, timeout: int, threshold_translation: float, threshold_rotation: float, wait_for_first_transform: bool, tf_topic_namespace: str, use_sim_time: bool):
if self.tf_topic_namespace != tf_topic_namespace:
raise ValueError("Runtime change of argument 'tf_topic_namespace' not supported.")
self.frame_id = frame_id
self.parent_frame_id = parent_frame_id
self.timeout = timeout
self.threshold_translation = threshold_translation
self.threshold_rotation = threshold_rotation
self.wait_for_first_transform = wait_for_first_transform
self.use_sim_time = use_sim_time

def update(self) -> py_trees.common.Status:
now = time.time()
transform = self.get_transform(self.frame_id, self.parent_frame_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def setup(self, **kwargs):
elif not success and not self.wait_for_first_message:
raise ValueError("Topic type must be specified. Please provide a valid topic type.")

def execute(self, topic_name: str, topic_type: str, latency: float, comparison_operator: bool, rolling_average_count: int, wait_for_first_message: bool):
if self.timer != 0:
raise ValueError("Action does not yet support to get retriggered")
self.timer = time.time()

def update(self) -> py_trees.common.Status:
Expand Down Expand Up @@ -122,13 +125,13 @@ def check_topic(self):

def call_subscriber(self):
datatype_in_list = self.topic_type.split(".")
self.topic_type = getattr(
topic_type = getattr(
importlib.import_module(".".join(datatype_in_list[:-1])),
datatype_in_list[-1]
)

self.subscription = self.node.create_subscription(
msg_type=self.topic_type,
msg_type=topic_type,
topic=self.topic_name,
callback=self._callback,
qos_profile=get_qos_preset_profile(['sensor_data']))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def __init__(self, associated_actor, distance: float, namespace_override: str):
self.node = None
self.subscriber = None
self.callback_group = None
if namespace_override:
self.namespace = namespace_override
self.namespace_override = namespace_override

def setup(self, **kwargs):
"""
Expand All @@ -52,8 +51,20 @@ def setup(self, **kwargs):
self.name, self.__class__.__name__)
raise KeyError(error_message) from e
self.callback_group = rclpy.callback_groups.MutuallyExclusiveCallbackGroup()
namespace = self.namespace
if self.namespace_override:
namespace = self.namespace_override
self.subscriber = self.node.create_subscription(
Odometry, self.namespace + '/odom', self._callback, 1000, callback_group=self.callback_group)
Odometry, namespace + '/odom', self._callback, 1000, callback_group=self.callback_group)

def execute(self, associated_actor, distance: float, namespace_override: str):
if self.namespace != associated_actor["namespace"] or self.namespace_override != namespace_override:
raise ValueError("Runtime change of namespace not supported.")
self.distance_expected = distance
self.distance_traveled = 0.0
self.previous_x = 0
self.previous_y = 0
self.first_run = True

def _callback(self, msg):
'''
Expand All @@ -80,15 +91,6 @@ def calculate_distance(self, msg):
self.previous_x = msg.pose.pose.position.x
self.previous_y = msg.pose.pose.position.y

def initialise(self):
'''
Initialize before ticking.
'''
self.distance_traveled = 0.0
self.previous_x = 0
self.previous_y = 0
self.first_run = True

def update(self) -> py_trees.common.Status:
"""
Check if the traveled distance is reached
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def setup(self, **kwargs):
)
self.feedback_message = f"Waiting for log" # pylint: disable= attribute-defined-outside-init

def execute(self, values: list, module_name: str):
self.module_name = module_name
self.values = values
self.found = None

def update(self) -> py_trees.common.Status:
"""
Wait for specified log entries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def __init__(self, service_name: str, service_type: str, data: str):
self.node = None
self.client = None
self.future = None
self.service_type = service_type
self.service_type_str = service_type
self.service_type = None
self.service_name = service_name
self.data_str = data
try:
trimmed_data = data.encode('utf-8').decode('unicode_escape')
trimmed_data = self.data_str.encode('utf-8').decode('unicode_escape')
self.data = literal_eval(trimmed_data)
except Exception as e: # pylint: disable=broad-except
raise ValueError(f"Error while parsing sevice call data:") from e
Expand All @@ -66,7 +68,7 @@ def setup(self, **kwargs):
self.name, self.__class__.__name__)
raise KeyError(error_message) from e

datatype_in_list = self.service_type.split(".")
datatype_in_list = self.service_type_str.split(".")
try:
self.service_type = getattr(
importlib.import_module(".".join(datatype_in_list[0:-1])),
Expand All @@ -77,6 +79,11 @@ def setup(self, **kwargs):
self.client = self.node.create_client(
self.service_type, self.service_name, callback_group=self.cb_group)

def execute(self, service_name: str, service_type: str, data: str):
if self.service_name != service_name or self.service_type_str != service_type or self.data_str != data:
raise ValueError("service_name, service_type and data arguments are not changeable during runtime.")
self.current_state = ServiceCallActionState.IDLE

def update(self) -> py_trees.common.Status:
"""
Execute states
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class for setting a node parameter
"""

def __init__(self, node_name: str, parameter_name: str, parameter_value: str):
self.node_name = node_name
self.parameter_name = parameter_name
self.parameter_value = parameter_value
service_name = node_name + '/set_parameters'
if not service_name.startswith('/'):
service_name = '/' + service_name
Expand Down Expand Up @@ -65,6 +68,10 @@ def __init__(self, node_name: str, parameter_name: str, parameter_value: str):
service_type='rcl_interfaces.srv.SetParameters',
data='{ "parameters": [{ "name": "' + parameter_name + '", "value": { "type": ' + str(parameter_type) + ', "' + parameter_assign_name + '": ' + parameter_value + '}}]}')

def execute(self, node_name: str, parameter_name: str, parameter_value: str): # pylint: disable=arguments-differ,arguments-renamed
if self.node_name != node_name or self.parameter_name != parameter_name or self.parameter_value != parameter_value:
raise ValueError("node_name, parameter_name and parameter_value are not changeable during runtime.")

@staticmethod
def is_float(element: any) -> bool:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,5 @@ import osc.ros
scenario test_ros_service_call:
timeout(30s)
do serial:
service_call() with:
keep(it.service_name == '/bla')
keep(it.service_type == 'std_srvs.srv.SetBool')
keep(it.data == '{\"data\": True}')
service_call('/bla', 'std_srvs.srv.SetBool', '{\"data\": True}')
emit end
19 changes: 19 additions & 0 deletions scenario_execution_ros/test/test_ros_log_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def test_success(self):
self.execute(scenario_content)
self.assertTrue(self.scenario_execution_ros.process_results())

def test_success_repeat(self):
scenario_content = """
import osc.ros
import osc.helpers
scenario test_success:
do parallel:
serial:
serial:
repeat(2)
log_check(values: ['ERROR'])
emit end
time_out: serial:
wait elapsed(10s)
emit fail
"""
self.execute(scenario_content)
self.assertTrue(self.scenario_execution_ros.process_results())

def test_timeout(self):
scenario_content = """
import osc.helpers
Expand Down
27 changes: 26 additions & 1 deletion scenario_execution_ros/test/test_ros_service_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
from scenario_execution_ros import ROSScenarioExecution
from scenario_execution.model.osc2_parser import OpenScenario2Parser
from scenario_execution.utils.logging import Logger
from scenario_execution.model.model_to_py_tree import create_py_tree
from ament_index_python.packages import get_package_share_directory

from antlr4.InputStream import InputStream
import py_trees
from std_srvs.srv import SetBool

os.environ["PYTHONUNBUFFERED"] = '1'
Expand All @@ -46,6 +48,14 @@ def setUp(self):
self.srv = self.node.create_service(SetBool, "/bla", self.service_callback)
self.parser = OpenScenario2Parser(Logger('test', False))
self.scenario_execution_ros = ROSScenarioExecution()
self.tree = py_trees.composites.Sequence(name="", memory=True)

def execute(self, scenario_content):
parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content))
model = self.parser.create_internal_model(parsed_tree, self.tree, "test.osc", False)
self.tree = create_py_tree(model, self.tree, self.parser.logger, False)
self.scenario_execution_ros.tree = self.tree
self.scenario_execution_ros.run()

def tearDown(self):
self.node.destroy_node()
Expand All @@ -62,3 +72,18 @@ def test_success(self):
self.scenario_execution_ros.run()
self.assertTrue(self.scenario_execution_ros.process_results())
self.assertTrue(self.request_received)

def test_success_repeat(self):
scenario_content = """
import osc.helpers
import osc.ros
scenario test_ros_service_call:
timeout(30s)
do serial:
repeat(2)
service_call('/bla', 'std_srvs.srv.SetBool', '{\\\"data\\\": True}')
emit end
"""
self.execute(scenario_content)
self.assertTrue(self.scenario_execution_ros.process_results())

0 comments on commit 65835e9

Please sign in to comment.