From 6593c502141ccf306bfce7cd2d353679c7084159 Mon Sep 17 00:00:00 2001 From: Nikhil Date: Mon, 12 Aug 2024 13:24:56 +0200 Subject: [PATCH 1/4] fix build warnings (#152) --- examples/example_multi_robot/setup.py | 1 + libs/scenario_execution_kubernetes/setup.py | 4 ++-- test/scenario_execution_gazebo_test/setup.py | 4 ++-- test/scenario_execution_nav2_test/setup.py | 4 ++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/example_multi_robot/setup.py b/examples/example_multi_robot/setup.py index 270b5c1b..d4d2203d 100644 --- a/examples/example_multi_robot/setup.py +++ b/examples/example_multi_robot/setup.py @@ -27,6 +27,7 @@ data_files=[ ('share/ament_index/resource_index/packages', ['resource/' + PACKAGE_NAME]), + ('share/' + PACKAGE_NAME, ['package.xml']), (os.path.join('share', PACKAGE_NAME, 'models'), glob('models/*.sdf')), (os.path.join('share', PACKAGE_NAME, 'launch'), glob('launch/*.py')), ], diff --git a/libs/scenario_execution_kubernetes/setup.py b/libs/scenario_execution_kubernetes/setup.py index 84371d0e..a75043d7 100644 --- a/libs/scenario_execution_kubernetes/setup.py +++ b/libs/scenario_execution_kubernetes/setup.py @@ -16,14 +16,14 @@ from glob import glob import os -from setuptools import find_packages, setup +from setuptools import find_namespace_packages, setup PACKAGE_NAME = 'scenario_execution_kubernetes' setup( name=PACKAGE_NAME, version='1.1.0', - packages=find_packages(), + packages=find_namespace_packages(), data_files=[ ('share/ament_index/resource_index/packages', ['resource/' + PACKAGE_NAME]), diff --git a/test/scenario_execution_gazebo_test/setup.py b/test/scenario_execution_gazebo_test/setup.py index 04483aaa..10bd92f9 100644 --- a/test/scenario_execution_gazebo_test/setup.py +++ b/test/scenario_execution_gazebo_test/setup.py @@ -17,14 +17,14 @@ """ Setup python package """ from glob import glob import os -from setuptools import find_namespace_packages, setup +from setuptools import find_packages, setup PACKAGE_NAME = 'scenario_execution_gazebo_test' setup( name=PACKAGE_NAME, version='1.2.0', - packages=find_namespace_packages(), + packages=find_packages(), data_files=[ ('share/ament_index/resource_index/packages', ['resource/' + PACKAGE_NAME]), diff --git a/test/scenario_execution_nav2_test/setup.py b/test/scenario_execution_nav2_test/setup.py index 0bed80a3..3208ed9e 100644 --- a/test/scenario_execution_nav2_test/setup.py +++ b/test/scenario_execution_nav2_test/setup.py @@ -17,14 +17,14 @@ """ Setup python package """ from glob import glob import os -from setuptools import find_namespace_packages, setup +from setuptools import find_packages, setup PACKAGE_NAME = 'scenario_execution_nav2_test' setup( name=PACKAGE_NAME, version='1.2.0', - packages=find_namespace_packages(), + packages=find_packages(), data_files=[ ('share/ament_index/resource_index/packages', ['resource/' + PACKAGE_NAME]), From b8d7b4c3699ca3828d05d47b6faf8f0ff11b4e08 Mon Sep 17 00:00:00 2001 From: fred-labs Date: Mon, 12 Aug 2024 17:22:31 +0200 Subject: [PATCH 2/4] add support for more kubernetes functionality (#153) --- docs/libraries.rst | 64 ++++++++++++ .../kubernetes_base_action.py | 78 +++++++++++++++ .../kubernetes_create_from_yaml.py | 8 +- .../kubernetes_patch_pod.py | 39 ++++++++ .../kubernetes_pod_exec.py | 99 +++++++++++++++++++ .../lib_osc/kubernetes.osc | 14 ++- .../test_kubernetes_create_delete.osc | 2 +- .../scenarios/test_kubernetes_pod_exec.osc | 14 +++ libs/scenario_execution_kubernetes/setup.py | 2 + 9 files changed, 315 insertions(+), 5 deletions(-) create mode 100644 libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_base_action.py create mode 100644 libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_patch_pod.py create mode 100644 libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_pod_exec.py create mode 100644 libs/scenario_execution_kubernetes/scenarios/test_kubernetes_pod_exec.osc diff --git a/docs/libraries.rst b/docs/libraries.rst index a3f15014..abd1c91f 100644 --- a/docs/libraries.rst +++ b/docs/libraries.rst @@ -447,6 +447,70 @@ Patch an existing Kubernetes network policy. - key-value pair to match (e.g., ``key_value("app", "pod_name"))`` +``kubernetes_patch_pod()`` +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Patch an existing pod. If patching resources, please check `feature gates `__ + +.. list-table:: + :widths: 15 15 5 65 + :header-rows: 1 + :class: tight-table + + * - Parameter + - Type + - Default + - Description + * - ``namespace`` + - ``string`` + - ``default`` + - Kubernetes namespace + * - ``within_cluster`` + - ``bool`` + - ``false`` + - set to true if you want to access the cluster from within a running container/pod + * - ``target`` + - ``string`` + - + - The target pod to patch + * - ``body`` + - ``string`` + - + - Patch to apply. Example: ``'{\"spec\":{\"containers\":[{\"name\":\"main\", \"resources\":{\"requests\":{\"cpu\":\"200m\"}, \"limits\":{\"cpu\":\"200m\"}}}]}}'`` + + +``kubernetes_pod_exec()`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Execute a command within a running pod + +.. list-table:: + :widths: 15 15 5 65 + :header-rows: 1 + :class: tight-table + + * - Parameter + - Type + - Default + - Description + * - ``namespace`` + - ``string`` + - ``default`` + - Kubernetes namespace + * - ``within_cluster`` + - ``bool`` + - ``false`` + - set to true if you want to access the cluster from within a running container/pod + * - ``target`` + - ``string`` + - + - The target pod to execute the command in + * - ``command`` + - ``list of string`` + - + - Command to execute + + ``kubernetes_wait_for_network_policy_status()`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_base_action.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_base_action.py new file mode 100644 index 00000000..6c45509d --- /dev/null +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_base_action.py @@ -0,0 +1,78 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from kubernetes import client, config +from enum import Enum +import py_trees +import json +from scenario_execution.actions.base_action import BaseAction + + +class KubernetesBaseActionState(Enum): + IDLE = 1 + REQUEST_SENT = 2 + FAILURE = 3 + + +class KubernetesBaseAction(BaseAction): + + def __init__(self, namespace: str, within_cluster: bool): + super().__init__() + self.namespace = namespace + self.within_cluster = within_cluster + self.client = None + self.current_state = KubernetesBaseActionState.IDLE + self.current_request = None + + def setup(self, **kwargs): + if self.within_cluster: + config.load_incluster_config() + else: + config.load_kube_config() + self.client = client.CoreV1Api() + + def execute(self, namespace: str, within_cluster: bool): + self.namespace = namespace + if within_cluster != self.within_cluster: + raise ValueError("parameter 'within_cluster' is not allowed to change since initialization.") + + def update(self) -> py_trees.common.Status: # pylint: disable=too-many-return-statements + if self.current_state == KubernetesBaseActionState.IDLE: + self.current_request = self.kubernetes_call() + self.current_state = KubernetesBaseActionState.REQUEST_SENT + return py_trees.common.Status.RUNNING + elif self.current_state == KubernetesBaseActionState.REQUEST_SENT: + success = True + if self.current_request.ready(): + if not self.current_request.successful(): + try: + self.current_request.get() + except client.exceptions.ApiException as e: + message = "" + body = json.loads(e.body) + if "message" in body: + message = f", message: '{body['message']}'" + self.feedback_message = f"Failure! Reason: {e.reason} {message}" # pylint: disable= attribute-defined-outside-init + success = False + if success: + return py_trees.common.Status.SUCCESS + else: + return py_trees.common.Status.FAILURE + return py_trees.common.Status.FAILURE + + def kubernetes_call(self): + # Use async_req = True, namespace=self.namespace + raise NotImplementedError("Implement in derived action") diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_create_from_yaml.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_create_from_yaml.py index 8ea32f81..9c91a20f 100644 --- a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_create_from_yaml.py +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_create_from_yaml.py @@ -47,8 +47,12 @@ def setup(self, **kwargs): def update(self) -> py_trees.common.Status: # pylint: disable=too-many-return-statements if self.current_state == KubernetesCreateFromYamlActionState.IDLE: - self.current_request = utils.create_from_yaml( - self.client, self.yaml_file, verbose=False, namespace=self.namespace, async_req=True) + try: + self.current_request = utils.create_from_yaml( + self.client, self.yaml_file, verbose=False, namespace=self.namespace, async_req=True) + except Exception as e: # pylint: disable=broad-except + self.feedback_message = f"Error while creating from yaml: {e}" + return py_trees.common.Status.FAILURE self.current_state = KubernetesCreateFromYamlActionState.CREATION_REQUESTED self.feedback_message = f"Requested creation from yaml file '{self.yaml_file}' in namespace '{self.namespace}'" # pylint: disable= attribute-defined-outside-init return py_trees.common.Status.RUNNING diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_patch_pod.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_patch_pod.py new file mode 100644 index 00000000..33ee2a84 --- /dev/null +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_patch_pod.py @@ -0,0 +1,39 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from ast import literal_eval +from .kubernetes_base_action import KubernetesBaseAction + + +class KubernetesPatchPod(KubernetesBaseAction): + + def __init__(self, namespace: str, target: str, body: str, within_cluster: bool): + super().__init__(namespace, within_cluster) + self.target = target + self.body = None + + def execute(self, namespace: str, target: str, body: str, within_cluster: bool): # pylint: disable=arguments-differ + super().execute(namespace, within_cluster) + self.target = target + trimmed_data = body.encode('utf-8').decode('unicode_escape') + try: + self.body = literal_eval(trimmed_data) + except ValueError as e: + raise ValueError(f"Could not parse body '{trimmed_data}': {e}") from e + + def kubernetes_call(self): + self.feedback_message = f"Requested patching '{self.target}' in namespace '{self.namespace}'" # pylint: disable= attribute-defined-outside-init + return self.client.patch_namespaced_pod(self.target, body=self.body, namespace=self.namespace, async_req=True) diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_pod_exec.py b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_pod_exec.py new file mode 100644 index 00000000..008c90c7 --- /dev/null +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/kubernetes_pod_exec.py @@ -0,0 +1,99 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import py_trees +from scenario_execution.actions.base_action import BaseAction +import queue +import threading +from kubernetes import client, config, stream +from enum import Enum + + +class KubernetesPodExecState(Enum): + IDLE = 1 + RUNNING = 2 + FAILURE = 3 + + +class KubernetesPodExec(BaseAction): + + def __init__(self, target: str, command: list, namespace: str, within_cluster: bool): + super().__init__() + self.target = target + self.namespace = namespace + self.command = command + self.within_cluster = within_cluster + self.client = None + self.reponse_queue = queue.Queue() + self.current_state = KubernetesPodExecState.IDLE + self.output_queue = queue.Queue() + + def setup(self, **kwargs): + if self.within_cluster: + config.load_incluster_config() + else: + config.load_kube_config() + self.client = client.CoreV1Api() + + self.exec_thread = threading.Thread(target=self.pod_exec, daemon=True) + + def execute(self, target: str, command: list, namespace: str, within_cluster: bool): + if within_cluster != self.within_cluster: + raise ValueError("parameter 'within_cluster' is not allowed to change since initialization.") + self.target = target + self.namespace = namespace + self.command = command + self.current_state = KubernetesPodExecState.IDLE + + def update(self) -> py_trees.common.Status: + if self.current_state == KubernetesPodExecState.IDLE: + self.current_state = KubernetesPodExecState.RUNNING + self.feedback_message = f"Executing on pod '{self.target}': {self.command}..." # pylint: disable= attribute-defined-outside-init + self.exec_thread.start() + return py_trees.common.Status.RUNNING + elif self.current_state == KubernetesPodExecState.RUNNING: + while not self.output_queue.empty(): + self.logger.debug(self.output_queue.get()) + try: + response = self.reponse_queue.get_nowait() + try: + if response.returncode == 0: + self.feedback_message = f"Execution successful." # pylint: disable= attribute-defined-outside-init + return py_trees.common.Status.SUCCESS + except ValueError: + self.feedback_message = f"Error while executing." # pylint: disable= attribute-defined-outside-init + except queue.Empty: + return py_trees.common.Status.RUNNING + + return py_trees.common.Status.FAILURE + + def pod_exec(self): + resp = stream.stream(self.client.connect_get_namespaced_pod_exec, + self.target, + self.namespace, + command=self.command, + stderr=True, stdin=False, + stdout=True, tty=False, + _preload_content=False) + + while resp.is_open(): + resp.update(timeout=0.1) + if resp.peek_stdout(): + self.output_queue.put(resp.read_stdout()) + if resp.peek_stderr(): + self.output_queue.put(resp.read_stderr()) + + self.reponse_queue.put(resp) diff --git a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/lib_osc/kubernetes.osc b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/lib_osc/kubernetes.osc index 14910244..47cebcb4 100644 --- a/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/lib_osc/kubernetes.osc +++ b/libs/scenario_execution_kubernetes/scenario_execution_kubernetes/lib_osc/kubernetes.osc @@ -35,9 +35,19 @@ action kubernetes_delete inherits kubernetes_base_action: action kubernetes_patch_network_policy inherits kubernetes_base_action: # patch an existing network policy - target: string # network-policy name to monitor + target: string # network-policy to patch network_enabled: bool # should the network be enabled? - match_label: key_value # key-value pair to match + match_label: key_value + +action kubernetes_patch_pod inherits kubernetes_base_action: + # patch an existing pod. If patching resources, please check feature gates: https://kubernetes.io/docs/tasks/configure-pod-container/resize-container-resources/#container-resize-policies + target: string # pod to patch + body: string # patch to apply + +action kubernetes_pod_exec inherits kubernetes_base_action: + # execute a command within a running pod + target: string # pod to patch + command: list of string # command to execute action kubernetes_wait_for_network_policy_status inherits kubernetes_base_action: # wait for a network-policy to reach the specified state diff --git a/libs/scenario_execution_kubernetes/scenarios/test_kubernetes_create_delete.osc b/libs/scenario_execution_kubernetes/scenarios/test_kubernetes_create_delete.osc index 81a558cb..7f98de05 100644 --- a/libs/scenario_execution_kubernetes/scenarios/test_kubernetes_create_delete.osc +++ b/libs/scenario_execution_kubernetes/scenarios/test_kubernetes_create_delete.osc @@ -3,7 +3,7 @@ import osc.kubernetes import osc.helpers scenario test_kubernetes_create_from_yaml: - timeout(30s) + timeout(60s) do serial: kubernetes_create_from_yaml(yaml_file: "test.yaml") kubernetes_wait_for_pod_status(target: "test", status: kubernetes_pod_status!running) diff --git a/libs/scenario_execution_kubernetes/scenarios/test_kubernetes_pod_exec.osc b/libs/scenario_execution_kubernetes/scenarios/test_kubernetes_pod_exec.osc new file mode 100644 index 00000000..786b7236 --- /dev/null +++ b/libs/scenario_execution_kubernetes/scenarios/test_kubernetes_pod_exec.osc @@ -0,0 +1,14 @@ +import osc.standard.base +import osc.kubernetes +import osc.helpers + +scenario test_kubernetes_create_from_yaml: + timeout(60s) + do serial: + kubernetes_create_from_yaml(yaml_file: "test.yaml") + kubernetes_wait_for_pod_status(target: "test", status: kubernetes_pod_status!running) + kubernetes_patch_pod(target: "test", body: '{\"spec\":{\"containers\":[{\"name\":\"main\", \"resources\":{\"requests\":{\"cpu\":\"200m\"}, \"limits\":{\"cpu\":\"200m\"}}}]}}') + kubernetes_pod_exec(target: "test", command: ['sysbench', 'cpu', 'run']) + kubernetes_patch_pod(target: "test", body: '{\"spec\":{\"containers\":[{\"name\":\"main\", \"resources\":{\"requests\":{\"cpu\":\"800m\"}, \"limits\":{\"cpu\":\"800m\"}}}]}}') + kubernetes_pod_exec(target: "test", command: ['sysbench', 'cpu', 'run']) + kubernetes_delete(target: "test", element_type: kubernetes_element_type!pod) diff --git a/libs/scenario_execution_kubernetes/setup.py b/libs/scenario_execution_kubernetes/setup.py index a75043d7..9f23de7e 100644 --- a/libs/scenario_execution_kubernetes/setup.py +++ b/libs/scenario_execution_kubernetes/setup.py @@ -44,6 +44,8 @@ 'kubernetes_create_from_yaml = scenario_execution_kubernetes.kubernetes_create_from_yaml:KubernetesCreateFromYaml', 'kubernetes_delete = scenario_execution_kubernetes.kubernetes_delete:KubernetesDelete', 'kubernetes_patch_network_policy = scenario_execution_kubernetes.kubernetes_patch_network_policy:KubernetesPatchNetworkPolicy', + 'kubernetes_patch_pod = scenario_execution_kubernetes.kubernetes_patch_pod:KubernetesPatchPod', + 'kubernetes_pod_exec = scenario_execution_kubernetes.kubernetes_pod_exec:KubernetesPodExec', 'kubernetes_wait_for_network_policy_status = scenario_execution_kubernetes.kubernetes_wait_for_network_policy_status:KubernetesWaitForNetworkPolicyStatus', 'kubernetes_wait_for_pod_status = scenario_execution_kubernetes.kubernetes_wait_for_pod_status:KubernetesWaitForPodStatus', ], From 65835e9caec57d9d453fa7b40f151abdfc91742c Mon Sep 17 00:00:00 2001 From: fred-labs Date: Wed, 14 Aug 2024 09:58:32 +0200 Subject: [PATCH 3/4] Fix service_call with repeat() modifier (#154) --------- Co-authored-by: Nikhil --- .../actions/assert_lifecycle_state.py | 2 +- .../actions/assert_tf_moving.py | 11 ++++++++ .../actions/assert_topic_latency.py | 7 +++-- .../actions/odometry_distance_traveled.py | 26 +++++++++--------- .../actions/ros_log_check.py | 5 ++++ .../actions/ros_service_call.py | 13 ++++++--- .../actions/ros_set_node_parameter.py | 7 +++++ .../scenarios/test/test_ros_service_call.osc | 5 +--- .../test/test_ros_log_check.py | 19 +++++++++++++ .../test/test_ros_service_call.py | 27 ++++++++++++++++++- 10 files changed, 99 insertions(+), 23 deletions(-) diff --git a/scenario_execution_ros/scenario_execution_ros/actions/assert_lifecycle_state.py b/scenario_execution_ros/scenario_execution_ros/actions/assert_lifecycle_state.py index 58c9b598..7138b38a 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/assert_lifecycle_state.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/assert_lifecycle_state.py @@ -67,7 +67,7 @@ def setup(self, **kwargs): def execute(self, node_name: str, state_sequence: list, allow_initial_skip: bool, fail_on_unexpected: bool, keep_running: bool): if self.node_name != node_name or self.state_sequence != state_sequence: - raise ValueError(f"Updating node name or state_sequence not supported.") + raise ValueError("Runtime change of arguments 'name', 'state_sequence not supported.") if all(isinstance(state, tuple) and len(state) == 2 for state in self.state_sequence): self.state_sequence = [state[0] for state in self.state_sequence] diff --git a/scenario_execution_ros/scenario_execution_ros/actions/assert_tf_moving.py b/scenario_execution_ros/scenario_execution_ros/actions/assert_tf_moving.py index ad6f2052..0e13c518 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/assert_tf_moving.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/assert_tf_moving.py @@ -65,6 +65,17 @@ def setup(self, **kwargs): tf_static_topic=(tf_prefix + "/tf_static"), ) + def execute(self, frame_id: str, parent_frame_id: str, timeout: int, threshold_translation: float, threshold_rotation: float, wait_for_first_transform: bool, tf_topic_namespace: str, use_sim_time: bool): + if self.tf_topic_namespace != tf_topic_namespace: + raise ValueError("Runtime change of argument 'tf_topic_namespace' not supported.") + self.frame_id = frame_id + self.parent_frame_id = parent_frame_id + self.timeout = timeout + self.threshold_translation = threshold_translation + self.threshold_rotation = threshold_rotation + self.wait_for_first_transform = wait_for_first_transform + self.use_sim_time = use_sim_time + def update(self) -> py_trees.common.Status: now = time.time() transform = self.get_transform(self.frame_id, self.parent_frame_id) diff --git a/scenario_execution_ros/scenario_execution_ros/actions/assert_topic_latency.py b/scenario_execution_ros/scenario_execution_ros/actions/assert_topic_latency.py index a6a503a0..601fade6 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/assert_topic_latency.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/assert_topic_latency.py @@ -59,6 +59,9 @@ def setup(self, **kwargs): elif not success and not self.wait_for_first_message: raise ValueError("Topic type must be specified. Please provide a valid topic type.") + def execute(self, topic_name: str, topic_type: str, latency: float, comparison_operator: bool, rolling_average_count: int, wait_for_first_message: bool): + if self.timer != 0: + raise ValueError("Action does not yet support to get retriggered") self.timer = time.time() def update(self) -> py_trees.common.Status: @@ -122,13 +125,13 @@ def check_topic(self): def call_subscriber(self): datatype_in_list = self.topic_type.split(".") - self.topic_type = getattr( + topic_type = getattr( importlib.import_module(".".join(datatype_in_list[:-1])), datatype_in_list[-1] ) self.subscription = self.node.create_subscription( - msg_type=self.topic_type, + msg_type=topic_type, topic=self.topic_name, callback=self._callback, qos_profile=get_qos_preset_profile(['sensor_data'])) diff --git a/scenario_execution_ros/scenario_execution_ros/actions/odometry_distance_traveled.py b/scenario_execution_ros/scenario_execution_ros/actions/odometry_distance_traveled.py index bfebd328..d821e067 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/odometry_distance_traveled.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/odometry_distance_traveled.py @@ -38,8 +38,7 @@ def __init__(self, associated_actor, distance: float, namespace_override: str): self.node = None self.subscriber = None self.callback_group = None - if namespace_override: - self.namespace = namespace_override + self.namespace_override = namespace_override def setup(self, **kwargs): """ @@ -52,8 +51,20 @@ def setup(self, **kwargs): self.name, self.__class__.__name__) raise KeyError(error_message) from e self.callback_group = rclpy.callback_groups.MutuallyExclusiveCallbackGroup() + namespace = self.namespace + if self.namespace_override: + namespace = self.namespace_override self.subscriber = self.node.create_subscription( - Odometry, self.namespace + '/odom', self._callback, 1000, callback_group=self.callback_group) + Odometry, namespace + '/odom', self._callback, 1000, callback_group=self.callback_group) + + def execute(self, associated_actor, distance: float, namespace_override: str): + if self.namespace != associated_actor["namespace"] or self.namespace_override != namespace_override: + raise ValueError("Runtime change of namespace not supported.") + self.distance_expected = distance + self.distance_traveled = 0.0 + self.previous_x = 0 + self.previous_y = 0 + self.first_run = True def _callback(self, msg): ''' @@ -80,15 +91,6 @@ def calculate_distance(self, msg): self.previous_x = msg.pose.pose.position.x self.previous_y = msg.pose.pose.position.y - def initialise(self): - ''' - Initialize before ticking. - ''' - self.distance_traveled = 0.0 - self.previous_x = 0 - self.previous_y = 0 - self.first_run = True - def update(self) -> py_trees.common.Status: """ Check if the traveled distance is reached diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_log_check.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_log_check.py index 66460735..617b6a16 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_log_check.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_log_check.py @@ -64,6 +64,11 @@ def setup(self, **kwargs): ) self.feedback_message = f"Waiting for log" # pylint: disable= attribute-defined-outside-init + def execute(self, values: list, module_name: str): + self.module_name = module_name + self.values = values + self.found = None + def update(self) -> py_trees.common.Status: """ Wait for specified log entries diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_service_call.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_service_call.py index 62c1bd6e..f22e6428 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_service_call.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_service_call.py @@ -44,10 +44,12 @@ def __init__(self, service_name: str, service_type: str, data: str): self.node = None self.client = None self.future = None - self.service_type = service_type + self.service_type_str = service_type + self.service_type = None self.service_name = service_name + self.data_str = data try: - trimmed_data = data.encode('utf-8').decode('unicode_escape') + trimmed_data = self.data_str.encode('utf-8').decode('unicode_escape') self.data = literal_eval(trimmed_data) except Exception as e: # pylint: disable=broad-except raise ValueError(f"Error while parsing sevice call data:") from e @@ -66,7 +68,7 @@ def setup(self, **kwargs): self.name, self.__class__.__name__) raise KeyError(error_message) from e - datatype_in_list = self.service_type.split(".") + datatype_in_list = self.service_type_str.split(".") try: self.service_type = getattr( importlib.import_module(".".join(datatype_in_list[0:-1])), @@ -77,6 +79,11 @@ def setup(self, **kwargs): self.client = self.node.create_client( self.service_type, self.service_name, callback_group=self.cb_group) + def execute(self, service_name: str, service_type: str, data: str): + if self.service_name != service_name or self.service_type_str != service_type or self.data_str != data: + raise ValueError("service_name, service_type and data arguments are not changeable during runtime.") + self.current_state = ServiceCallActionState.IDLE + def update(self) -> py_trees.common.Status: """ Execute states diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_set_node_parameter.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_set_node_parameter.py index 66e33faf..3464ad7c 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_set_node_parameter.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_set_node_parameter.py @@ -26,6 +26,9 @@ class for setting a node parameter """ def __init__(self, node_name: str, parameter_name: str, parameter_value: str): + self.node_name = node_name + self.parameter_name = parameter_name + self.parameter_value = parameter_value service_name = node_name + '/set_parameters' if not service_name.startswith('/'): service_name = '/' + service_name @@ -65,6 +68,10 @@ def __init__(self, node_name: str, parameter_name: str, parameter_value: str): service_type='rcl_interfaces.srv.SetParameters', data='{ "parameters": [{ "name": "' + parameter_name + '", "value": { "type": ' + str(parameter_type) + ', "' + parameter_assign_name + '": ' + parameter_value + '}}]}') + def execute(self, node_name: str, parameter_name: str, parameter_value: str): # pylint: disable=arguments-differ,arguments-renamed + if self.node_name != node_name or self.parameter_name != parameter_name or self.parameter_value != parameter_value: + raise ValueError("node_name, parameter_name and parameter_value are not changeable during runtime.") + @staticmethod def is_float(element: any) -> bool: """ diff --git a/scenario_execution_ros/scenarios/test/test_ros_service_call.osc b/scenario_execution_ros/scenarios/test/test_ros_service_call.osc index e364f45e..9825a3c7 100644 --- a/scenario_execution_ros/scenarios/test/test_ros_service_call.osc +++ b/scenario_execution_ros/scenarios/test/test_ros_service_call.osc @@ -4,8 +4,5 @@ import osc.ros scenario test_ros_service_call: timeout(30s) do serial: - service_call() with: - keep(it.service_name == '/bla') - keep(it.service_type == 'std_srvs.srv.SetBool') - keep(it.data == '{\"data\": True}') + service_call('/bla', 'std_srvs.srv.SetBool', '{\"data\": True}') emit end diff --git a/scenario_execution_ros/test/test_ros_log_check.py b/scenario_execution_ros/test/test_ros_log_check.py index 1649f99c..c06081ee 100644 --- a/scenario_execution_ros/test/test_ros_log_check.py +++ b/scenario_execution_ros/test/test_ros_log_check.py @@ -81,6 +81,25 @@ def test_success(self): self.execute(scenario_content) self.assertTrue(self.scenario_execution_ros.process_results()) + def test_success_repeat(self): + scenario_content = """ +import osc.ros +import osc.helpers + +scenario test_success: + do parallel: + serial: + serial: + repeat(2) + log_check(values: ['ERROR']) + emit end + time_out: serial: + wait elapsed(10s) + emit fail +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution_ros.process_results()) + def test_timeout(self): scenario_content = """ import osc.helpers diff --git a/scenario_execution_ros/test/test_ros_service_call.py b/scenario_execution_ros/test/test_ros_service_call.py index ffa65cba..ad9916e9 100644 --- a/scenario_execution_ros/test/test_ros_service_call.py +++ b/scenario_execution_ros/test/test_ros_service_call.py @@ -21,8 +21,10 @@ from scenario_execution_ros import ROSScenarioExecution from scenario_execution.model.osc2_parser import OpenScenario2Parser from scenario_execution.utils.logging import Logger +from scenario_execution.model.model_to_py_tree import create_py_tree from ament_index_python.packages import get_package_share_directory - +from antlr4.InputStream import InputStream +import py_trees from std_srvs.srv import SetBool os.environ["PYTHONUNBUFFERED"] = '1' @@ -46,6 +48,14 @@ def setUp(self): self.srv = self.node.create_service(SetBool, "/bla", self.service_callback) self.parser = OpenScenario2Parser(Logger('test', False)) self.scenario_execution_ros = ROSScenarioExecution() + self.tree = py_trees.composites.Sequence(name="", memory=True) + + def execute(self, scenario_content): + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + model = self.parser.create_internal_model(parsed_tree, self.tree, "test.osc", False) + self.tree = create_py_tree(model, self.tree, self.parser.logger, False) + self.scenario_execution_ros.tree = self.tree + self.scenario_execution_ros.run() def tearDown(self): self.node.destroy_node() @@ -62,3 +72,18 @@ def test_success(self): self.scenario_execution_ros.run() self.assertTrue(self.scenario_execution_ros.process_results()) self.assertTrue(self.request_received) + + def test_success_repeat(self): + scenario_content = """ +import osc.helpers +import osc.ros + +scenario test_ros_service_call: + timeout(30s) + do serial: + repeat(2) + service_call('/bla', 'std_srvs.srv.SetBool', '{\\\"data\\\": True}') + emit end +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution_ros.process_results()) From 3de474812e26620a842c85e024483fab626766b1 Mon Sep 17 00:00:00 2001 From: fred-labs Date: Wed, 14 Aug 2024 10:41:39 +0200 Subject: [PATCH 4/4] Add support for expressions (#157) --- .github/workflows/test_build.yml | 2 +- docs/openscenario2.rst | 2 +- .../scenario_coverage/scenario_variation.py | 4 +- .../model/model_to_py_tree.py | 40 ++++- .../scenario_execution/model/types.py | 167 +++++++++++++++--- scenario_execution/test/test_expression.py | 87 ++++----- .../actions/ros_topic_check_data.py | 2 +- .../actions/ros_topic_monitor.py | 26 ++- .../scenario_execution_ros/lib_osc/ros.osc | 1 + .../test/test_topic_monitor.py | 81 +++++++++ .../actions/set_blackboard_var.py | 31 ++++ test/scenario_execution_test/setup.py | 1 + .../test/test_expression_with_var.py | 98 ++++++++++ 13 files changed, 463 insertions(+), 79 deletions(-) create mode 100644 test/scenario_execution_test/scenario_execution_test/actions/set_blackboard_var.py create mode 100644 test/scenario_execution_test/test/test_expression_with_var.py diff --git a/.github/workflows/test_build.yml b/.github/workflows/test_build.yml index 26f7f571..e84291f8 100644 --- a/.github/workflows/test_build.yml +++ b/.github/workflows/test_build.yml @@ -417,7 +417,7 @@ jobs: export ROS_DOMAIN_ID=2 export IGN_PARTITION=${HOSTNAME}:${GITHUB_RUN_ID} # shellcheck disable=SC1083 - scenario_batch_execution -i test/scenario_execution_nav2_test/scenarios/ -o test_scenario_execution_nav2 -- ros2 launch tb4_sim_scenario sim_nav_scenario_launch.py scenario:={SCENARIO} output_dir:={OUTPUT_DIR} headless:=True + scenario_batch_execution -i test/scenario_execution_nav2_test/scenarios/ -o test_scenario_execution_nav2 --ignore-process-return-value -- ros2 launch tb4_sim_scenario sim_nav_scenario_launch.py scenario:={SCENARIO} output_dir:={OUTPUT_DIR} headless:=True - name: Upload result uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 if: always() diff --git a/docs/openscenario2.rst b/docs/openscenario2.rst index cecf9e7c..310b4862 100644 --- a/docs/openscenario2.rst +++ b/docs/openscenario2.rst @@ -66,7 +66,7 @@ Element Tag Support Notes ``enum`` :raw-html:`✅` ``event`` :raw-html:`✅` ``every`` :raw-html:`❌` -``expression`` :raw-html:`❌` +``expression`` :raw-html:`✅` ``extend`` :raw-html:`❌` ``external`` :raw-html:`❌` ``fall`` :raw-html:`❌` diff --git a/scenario_coverage/scenario_coverage/scenario_variation.py b/scenario_coverage/scenario_coverage/scenario_variation.py index 7a03dd91..83c205a3 100644 --- a/scenario_coverage/scenario_coverage/scenario_variation.py +++ b/scenario_coverage/scenario_coverage/scenario_variation.py @@ -25,7 +25,7 @@ import py_trees from scenario_execution.model.osc2_parser import OpenScenario2Parser from scenario_execution.model.model_resolver import resolve_internal_model -from scenario_execution.model.types import RelationExpression, ListExpression, FieldAccessExpression, Expression, print_tree, serialize, to_string +from scenario_execution.model.types import RelationExpression, ListExpression, FieldAccessExpression, ModelExpression, print_tree, serialize, to_string from scenario_execution.utils.logging import Logger @@ -138,7 +138,7 @@ def save_resulting_scenarios(self, models): # create description variation_descriptions = [] for descr, entry in model[1]: - if isinstance(entry, Expression): + if isinstance(entry, ModelExpression): val = None for child in entry.get_children(): if not isinstance(child, FieldAccessExpression): diff --git a/scenario_execution/scenario_execution/model/model_to_py_tree.py b/scenario_execution/scenario_execution/model/model_to_py_tree.py index cbe394e8..f5a0074a 100644 --- a/scenario_execution/scenario_execution/model/model_to_py_tree.py +++ b/scenario_execution/scenario_execution/model/model_to_py_tree.py @@ -18,10 +18,9 @@ import py_trees from py_trees.common import Access, Status from pkg_resources import iter_entry_points - import inspect -from scenario_execution.model.types import ActionDeclaration, EventReference, FunctionApplicationExpression, ModifierInvocation, ScenarioDeclaration, DoMember, WaitDirective, EmitDirective, BehaviorInvocation, EventCondition, EventDeclaration, RelationExpression, LogicalExpression, ElapsedExpression, PhysicalLiteral, ModifierDeclaration +from scenario_execution.model.types import KeepConstraintDeclaration, visit_expression, ActionDeclaration, BinaryExpression, EventReference, Expression, FunctionApplicationExpression, ModifierInvocation, ScenarioDeclaration, DoMember, WaitDirective, EmitDirective, BehaviorInvocation, EventCondition, EventDeclaration, RelationExpression, LogicalExpression, ElapsedExpression, PhysicalLiteral, ModifierDeclaration from scenario_execution.model.model_base_visitor import ModelBaseVisitor from scenario_execution.model.error import OSC2ParsingError from scenario_execution.actions.base_action import BaseAction @@ -103,6 +102,20 @@ def update(self): return Status.SUCCESS +class ExpressionBehavior(py_trees.behaviour.Behaviour): + + def __init__(self, name: "ExpressionBehavior", expression: Expression): + super().__init__(name) + + self.expression = expression + + def update(self): + if self.expression.eval(): + return Status.SUCCESS + else: + return Status.RUNNING + + class ModelToPyTree(object): def __init__(self, logger): @@ -122,6 +135,7 @@ class BehaviorInit(ModelBaseVisitor): def __init__(self, logger, tree) -> None: super().__init__() self.logger = logger + self.blackboard = None if not isinstance(tree, py_trees.composites.Sequence): raise ValueError("ModelToPyTree requires a py-tree sequence as input") self.tree = tree @@ -348,19 +362,25 @@ def visit_event_reference(self, node: EventReference): def visit_event_condition(self, node: EventCondition): expression = "" for child in node.get_children(): - if isinstance(child, RelationExpression): - raise NotImplementedError() - elif isinstance(child, LogicalExpression): - raise NotImplementedError() + if isinstance(child, (RelationExpression, LogicalExpression)): + expression = ExpressionBehavior(name=node.get_ctx()[2], expression=self.visit(child)) elif isinstance(child, ElapsedExpression): elapsed_condition = self.visit_elapsed_expression(child) - expression = py_trees.timers.Timer( - name=f"wait {elapsed_condition}s", duration=float(elapsed_condition)) + expression = py_trees.timers.Timer(name=f"wait {elapsed_condition}s", duration=float(elapsed_condition)) else: raise OSC2ParsingError( msg=f'Invalid event condition {child}', context=node.get_ctx()) return expression + def visit_relation_expression(self, node: RelationExpression): + return visit_expression(node, self.blackboard) + + def visit_logical_expression(self, node: LogicalExpression): + return visit_expression(node, self.blackboard) + + def visit_binary_expression(self, node: BinaryExpression): + return visit_expression(node, self.blackboard) + def visit_elapsed_expression(self, node: ElapsedExpression): elem = node.find_first_child_of_type(PhysicalLiteral) if not elem: @@ -389,3 +409,7 @@ def visit_modifier_invocation(self, node: ModifierInvocation): self.create_decorator(node.modifier, resolved_values) except ValueError as e: raise OSC2ParsingError(msg=f'ModifierDeclaration {e}.', context=node.get_ctx()) from e + + def visit_keep_constraint_declaration(self, node: KeepConstraintDeclaration): + # skip relation-expression + pass diff --git a/scenario_execution/scenario_execution/model/types.py b/scenario_execution/scenario_execution/model/types.py index e630b1b7..9903e826 100644 --- a/scenario_execution/scenario_execution/model/types.py +++ b/scenario_execution/scenario_execution/model/types.py @@ -18,6 +18,7 @@ from scenario_execution.model.error import OSC2ParsingError import sys import py_trees +import operator as op def print_tree(elem, logger, whitespace=""): @@ -338,16 +339,21 @@ def get_value_child(self): return None for child in self.get_children(): - if isinstance(child, (StringLiteral, FloatLiteral, BoolLiteral, IntegerLiteral, FunctionApplicationExpression, IdentifierReference, PhysicalLiteral, EnumValueReference, ListExpression)): + if isinstance(child, (StringLiteral, FloatLiteral, BoolLiteral, IntegerLiteral, FunctionApplicationExpression, IdentifierReference, PhysicalLiteral, EnumValueReference, ListExpression, BinaryExpression, RelationExpression, LogicalExpression)): return child + elif isinstance(child, KeepConstraintDeclaration): + pass + elif not isinstance(child, Type): + raise OSC2ParsingError(msg=f'Parameter has invalid value "{type(child).__name__}".', context=self.get_ctx()) return None def get_resolved_value(self, blackboard=None): param_type, is_list = self.get_type() vals = {} params = {} - if self.get_value_child(): - vals = self.get_value_child().get_resolved_value(blackboard) + val_child = self.get_value_child() + if val_child: + vals = val_child.get_resolved_value(blackboard) if isinstance(param_type, StructuredDeclaration) and not is_list: params = param_type.get_resolved_value(blackboard) @@ -447,7 +453,7 @@ def get_type_string(self): return self.name -class Expression(ModelElement): +class ModelExpression(ModelElement): pass @@ -1549,7 +1555,7 @@ def get_base_type(self): return self.modifier -class RiseExpression(Expression): +class RiseExpression(ModelExpression): def __init__(self): super().__init__() @@ -1569,7 +1575,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class FallExpression(Expression): +class FallExpression(ModelExpression): def __init__(self): super().__init__() @@ -1589,7 +1595,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class ElapsedExpression(Expression): +class ElapsedExpression(ModelExpression): def __init__(self): super().__init__() @@ -1609,7 +1615,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class EveryExpression(Expression): +class EveryExpression(ModelExpression): def __init__(self): super().__init__() @@ -1629,7 +1635,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class SampleExpression(Expression): +class SampleExpression(ModelExpression): def __init__(self): super().__init__() @@ -1649,7 +1655,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class CastExpression(Expression): +class CastExpression(ModelExpression): def __init__(self, object_def, target_type): super().__init__() @@ -1671,7 +1677,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class TypeTestExpression(Expression): +class TypeTestExpression(ModelExpression): def __init__(self, object_def, target_type): super().__init__() @@ -1693,7 +1699,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class ElementAccessExpression(Expression): +class ElementAccessExpression(ModelExpression): def __init__(self, list_name, index): super().__init__() @@ -1715,7 +1721,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class FunctionApplicationExpression(Expression): +class FunctionApplicationExpression(ModelExpression): def __init__(self, func_name): super().__init__() @@ -1775,7 +1781,7 @@ def get_type_string(self): return self.get_type()[0].name -class FieldAccessExpression(Expression): +class FieldAccessExpression(ModelExpression): def __init__(self, field_name): super().__init__() @@ -1796,7 +1802,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class BinaryExpression(Expression): +class BinaryExpression(ModelExpression): def __init__(self, operator): super().__init__() @@ -1816,8 +1822,24 @@ def accept(self, visitor): else: return visitor.visit_children(self) + def get_type_string(self): + type_string = None + for child in self.get_children(): + current = child.get_type_string() + if self.operator in ("/", "%", "*"): # multiplied by factor + if type_string is None or type_string in ("float", "int"): + type_string = current + else: + if type_string not in (current, type_string): + raise OSC2ParsingError(f'Children have different types {current}, {type_string}', context=self.get_ctx()) + type_string = current + return type_string + + def get_resolved_value(self, blackboard=None): + return visit_expression(self, blackboard).eval() + -class UnaryExpression(Expression): +class UnaryExpression(ModelExpression): def __init__(self, operator): super().__init__() @@ -1838,7 +1860,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class TernaryExpression(Expression): +class TernaryExpression(ModelExpression): def __init__(self): super().__init__() @@ -1858,7 +1880,7 @@ def accept(self, visitor): return visitor.visit_children(self) -class LogicalExpression(Expression): +class LogicalExpression(ModelExpression): def __init__(self, operator): super().__init__() @@ -1878,8 +1900,14 @@ def accept(self, visitor): else: return visitor.visit_children(self) + def get_type_string(self): + return "bool" + + def get_resolved_value(self, blackboard=None): + return visit_expression(self, blackboard).eval() -class RelationExpression(Expression): + +class RelationExpression(ModelExpression): def __init__(self, operator): super().__init__() @@ -1899,8 +1927,14 @@ def accept(self, visitor): else: return visitor.visit_children(self) + def get_type_string(self): + return "bool" + + def get_resolved_value(self, blackboard=None): + return visit_expression(self, blackboard).eval() + -class ListExpression(Expression): +class ListExpression(ModelExpression): def __init__(self): super().__init__() @@ -1935,7 +1969,7 @@ def get_resolved_value(self, blackboard=None): return value -class RangeExpression(Expression): +class RangeExpression(ModelExpression): def __init__(self): super().__init__() @@ -2198,6 +2232,8 @@ def get_type_string(self): def get_blackboard_reference(self, blackboard): if not isinstance(self.ref, list) or len(self.ref) == 0: raise ValueError("Variable Reference only supported if reference is list with at least one element") + if not isinstance(self.ref[0], ParameterDeclaration): + raise ValueError("Variable Reference only supported if reference is part of a parameter declaration") fqn = self.ref[0].get_fully_qualified_var_name(include_scenario=False) if blackboard is None: raise ValueError("Variable Reference found, but no blackboard client available.") @@ -2206,6 +2242,12 @@ def get_blackboard_reference(self, blackboard): blackboard.register_key(fqn, access=py_trees.common.Access.WRITE) return VariableReference(blackboard, fqn) + def get_variable_reference(self, blackboard): + if isinstance(self.ref, list) and any(isinstance(x, VariableDeclaration) for x in self.ref): + return self.get_blackboard_reference(blackboard) + else: + return None + def get_resolved_value(self, blackboard=None): if isinstance(self.ref, list): ref = self.ref[0] @@ -2222,3 +2264,86 @@ def get_resolved_value(self, blackboard=None): return val else: return self.ref.get_resolved_value(blackboard) + + +class Expression(object): + def __init__(self, left, right, operator) -> None: + self.left = left + self.right = right + self.operator = operator + + def resolve(self, param): + if isinstance(param, Expression): + return param.eval() + elif isinstance(param, VariableReference): + return param.get_value() + else: + return param + + def eval(self): + left = self.resolve(self.left) + if self.right is None: + return self.operator(left) + else: + right = self.resolve(self.right) + return self.operator(left, right) + + +def visit_expression(node, blackboard): + operator = None + single_child = False + if node.operator == "==": + operator = op.eq + elif node.operator == "!=": + operator = op.ne + elif node.operator == "<": + operator = op.lt + elif node.operator == "<=": + operator = op.le + elif node.operator == ">": + operator = op.gt + elif node.operator == ">=": + operator = op.ge + elif node.operator == "and": + operator = op.and_ + elif node.operator == "or": + operator = op.or_ + elif node.operator == "not": + single_child = True + operator = op.not_ + elif node.operator == "+": + operator = op.add + elif node.operator == "-": + operator = op.sub + elif node.operator == "*": + operator = op.mul + elif node.operator == "/": + operator = op.truediv + elif node.operator == "%": + operator = op.mod + else: + raise NotImplementedError(f"Unknown expression operator '{node.operator}'.") + + if not single_child and node.get_child_count() != 2: + raise ValueError("Expression is expected to have two children.") + + idx = 0 + args = [None, None] + for child in node.get_children(): + if isinstance(child, (RelationExpression, BinaryExpression, LogicalExpression)): + args[idx] = visit_expression(child, blackboard) + else: + if isinstance(child, IdentifierReference): + var_def = child.get_variable_reference(blackboard) + if var_def is not None: + args[idx] = var_def + else: + args[idx] = child.get_resolved_value(blackboard) + else: + args[idx] = child.get_resolved_value(blackboard) + idx += 1 + + if single_child: + return Expression(args[0], args[1], operator) + else: + return Expression(args[0], args[1], operator) diff --git a/scenario_execution/test/test_expression.py b/scenario_execution/test/test_expression.py index ab1a71b0..070273de 100644 --- a/scenario_execution/test/test_expression.py +++ b/scenario_execution/test/test_expression.py @@ -19,6 +19,7 @@ from scenario_execution.model.osc2_parser import OpenScenario2Parser from scenario_execution.utils.logging import Logger from antlr4.InputStream import InputStream +import py_trees class TestExpression(unittest.TestCase): @@ -29,8 +30,12 @@ class TestExpression(unittest.TestCase): def setUp(self) -> None: self.parser = OpenScenario2Parser(Logger('test', False)) + self.tree = py_trees.composites.Sequence(name="", memory=True) + + def parse(self, scenario_content): + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + return self.parser.create_internal_model(parsed_tree, self.tree, "test.osc", False) - @unittest.skip(reason="requires porting") def test_add(self): scenario_content = """ type time is SI(s: 1) @@ -41,14 +46,20 @@ def test_add(self): global test2: time = 2.0s + 1.1s global test3: time = 2.0s + 1ms """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + model = self.parse(scenario_content) + self.assertAlmostEqual(model._ModelElement__children[3].get_resolved_value(), 3.1) + self.assertAlmostEqual(model._ModelElement__children[4].get_resolved_value(), 3.1) + self.assertAlmostEqual(model._ModelElement__children[5].get_resolved_value(), 2.001) + + def test_add_different_types(self): + scenario_content = """ +type time is SI(s: 1) +unit s of time is SI(s: 1, factor: 1) - self.assertEqual(model._ModelElement__children[3].get_resolved_value(), 3.1) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), 3.1) - self.assertEqual(model._ModelElement__children[5].get_resolved_value(), 2.001) +global test2: time = 2.0s + 1.1 +""" + self.assertRaises(ValueError, self.parse, scenario_content) - @unittest.skip(reason="requires porting") def test_substract(self): scenario_content = """ type time is SI(s: 1) @@ -59,14 +70,20 @@ def test_substract(self): global test2: time = 2.0s - 1.1s global test3: time = 2.0s - 1ms """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + model = self.parse(scenario_content) + self.assertAlmostEqual(model._ModelElement__children[3].get_resolved_value(), 0.9) + self.assertAlmostEqual(model._ModelElement__children[4].get_resolved_value(), 0.9) + self.assertAlmostEqual(model._ModelElement__children[5].get_resolved_value(), 1.999) + + def test_substract_different_types(self): + scenario_content = """ +type time is SI(s: 1) +unit s of time is SI(s: 1, factor: 1) - self.assertEqual(model._ModelElement__children[3].get_resolved_value(), 0.9) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), 0.9) - self.assertEqual(model._ModelElement__children[5].get_resolved_value(), 1.999) +global test2: time = 2.0s - 1.1 +""" + self.assertRaises(ValueError, self.parse, scenario_content) - @unittest.skip(reason="requires porting") def test_multiply(self): scenario_content = """ type time is SI(s: 1) @@ -75,13 +92,10 @@ def test_multiply(self): global test1: float = 2.0 * 1.1 global test2: time = 2.0ms * 1.1 """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) + model = self.parse(scenario_content) + self.assertAlmostEqual(model._ModelElement__children[2].get_resolved_value(), 2.2) + self.assertAlmostEqual(model._ModelElement__children[3].get_resolved_value(), 0.0022) - self.assertEqual(model._ModelElement__children[3].get_resolved_value(), 2.2) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), 0.0022) - - @unittest.skip(reason="requires porting") def test_divide(self): scenario_content = """ type time is SI(s: 1) @@ -90,13 +104,10 @@ def test_divide(self): global test1: float = 5.0 / 2.0 global test2: time = 5.0ms / 2.0 """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) - - self.assertEqual(model._ModelElement__children[3].get_resolved_value(), 2.5) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), 0.0025) + model = self.parse(scenario_content) + self.assertAlmostEqual(model._ModelElement__children[2].get_resolved_value(), 2.5) + self.assertAlmostEqual(model._ModelElement__children[3].get_resolved_value(), 0.0025) - @unittest.skip(reason="requires porting") def test_relation(self): scenario_content = """ type time is SI(s: 1) @@ -110,37 +121,31 @@ def test_relation(self): global test6: bool = 5 >= 2 global test7: bool = 5 <= 2 """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) - + model = self.parse(scenario_content) + self.assertEqual(model._ModelElement__children[2].get_resolved_value(), True) self.assertEqual(model._ModelElement__children[3].get_resolved_value(), True) - self.assertEqual(model._ModelElement__children[4].get_resolved_value(), True) + self.assertEqual(model._ModelElement__children[4].get_resolved_value(), False) self.assertEqual(model._ModelElement__children[5].get_resolved_value(), False) - self.assertEqual(model._ModelElement__children[6].get_resolved_value(), False) + self.assertEqual(model._ModelElement__children[6].get_resolved_value(), True) self.assertEqual(model._ModelElement__children[7].get_resolved_value(), True) - self.assertEqual(model._ModelElement__children[8].get_resolved_value(), True) - self.assertEqual(model._ModelElement__children[9].get_resolved_value(), False) + self.assertEqual(model._ModelElement__children[8].get_resolved_value(), False) - @unittest.skip(reason="requires porting") def test_negation(self): scenario_content = """ -global test1: bool = not True +global test1: bool = not true global test1: bool = not 5 > 2 +global test1: bool = not false """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) - + model = self.parse(scenario_content) self.assertEqual(model._ModelElement__children[0].get_resolved_value(), False) self.assertEqual(model._ModelElement__children[1].get_resolved_value(), False) + self.assertEqual(model._ModelElement__children[2].get_resolved_value(), True) - @unittest.skip(reason="requires porting") def test_compound_expression(self): scenario_content = """ global test1: bool = 2 > 1 and 3 >= 2 global test1: bool = 2 > 1 or 3 < 2 """ - parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) - model = self.parser.create_internal_model(parsed_tree, "test.osc", False) - + model = self.parse(scenario_content) self.assertEqual(model._ModelElement__children[0].get_resolved_value(), True) self.assertEqual(model._ModelElement__children[1].get_resolved_value(), True) diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py index 5c07db2f..d1c1255c 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_check_data.py @@ -125,7 +125,7 @@ def check_data(self, msg): try: value = check_attr(msg) except AttributeError: - self.feedback_message = "Member name not found {self.member_name}]" + self.feedback_message = f"Member name not found {self.member_name}" self.found = self.comparison_operator(value, self.expected_value) def set_expected_value(self, expected_value_string): diff --git a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py index a07423a3..1279c0d5 100644 --- a/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py +++ b/scenario_execution_ros/scenario_execution_ros/actions/ros_topic_monitor.py @@ -19,13 +19,15 @@ from scenario_execution.model.types import VariableReference import rclpy import py_trees +import operator class RosTopicMonitor(BaseAction): - def __init__(self, topic_name: str, topic_type: str, target_variable: object, qos_profile: tuple): + def __init__(self, topic_name: str, topic_type: str, member_name: str, target_variable: object, qos_profile: tuple): super().__init__(resolve_variable_reference_arguments_in_execute=False) self.target_variable = None + self.member_name = member_name self.topic_type = topic_type self.qos_profile = qos_profile self.topic_name = topic_name @@ -43,8 +45,13 @@ def setup(self, **kwargs): self.name, self.__class__.__name__) raise KeyError(error_message) from e + msg_type = get_ros_message_type(self.topic_type) + + # check if member-name exists + self.get_value(msg_type()) + self.subscriber = self.node.create_subscription( - msg_type=get_ros_message_type(self.topic_type), + msg_type=msg_type, topic=self.topic_name, callback=self._callback, qos_profile=get_qos_preset_profile(self.qos_profile), @@ -52,16 +59,27 @@ def setup(self, **kwargs): ) self.feedback_message = f"Monitoring data on {self.topic_name}" # pylint: disable= attribute-defined-outside-init - def execute(self, topic_name, topic_type, target_variable, qos_profile): + def execute(self, topic_name: str, topic_type: str, member_name: str, target_variable: object, qos_profile: tuple): if self.topic_name != topic_name or self.topic_type != topic_type or self.qos_profile != qos_profile: raise ValueError("Updating topic parameters not supported.") if not isinstance(target_variable, VariableReference): raise ValueError(f"'target_variable' is expected to be a variable reference.") self.target_variable = target_variable + self.member_name = member_name def update(self) -> py_trees.common.Status: return py_trees.common.Status.SUCCESS def _callback(self, msg): if self.target_variable is not None: - self.target_variable.set_value(msg) + self.target_variable.set_value(self.get_value(msg)) + + def get_value(self, msg): + if self.member_name != "": + check_attr = operator.attrgetter(self.member_name) + try: + return check_attr(msg) + except AttributeError as e: + raise ValueError(f"invalid member_name '{self.member_name}'") from e + else: + return msg diff --git a/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc b/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc index 776c2463..548c6a79 100644 --- a/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc +++ b/scenario_execution_ros/scenario_execution_ros/lib_osc/ros.osc @@ -128,6 +128,7 @@ action topic_monitor: topic_type: string # class of the message type (e.g. std_msgs.msg.String) target_variable: string # name of the variable (e.g. a 'var' within an actor instance) qos_profile: qos_preset_profiles = qos_preset_profiles!system_default # qos profile for the subscriber + member_name: string = "" # if not empty, only the value of the member is stored within the variable action topic_publish: # publish a message on a topic diff --git a/test/scenario_execution_ros_test/test/test_topic_monitor.py b/test/scenario_execution_ros_test/test/test_topic_monitor.py index 1d62138a..96ecd372 100644 --- a/test/scenario_execution_ros_test/test/test_topic_monitor.py +++ b/test/scenario_execution_ros_test/test/test_topic_monitor.py @@ -90,3 +90,84 @@ def test_success(self): with open(self.tmp_file.name) as f: result = f.read() self.assertEqual(result, "std_msgs.msg.String(data='TEST')") + + def test_member_success(self): + scenario_content = """ +import osc.ros + +action store_action: + file_path: string + value: string + +actor test_actor: + var test: string = "one" + +scenario test_scenario: + foo: test_actor + + do parallel: + serial: + wait elapsed(1s) + topic_publish("/bla", "std_msgs.msg.String", '{\\\"data\\\": \\\"TEST\\\"}') + serial: + topic_monitor("/bla", "std_msgs.msg.String", foo.test, member_name: "data") + wait elapsed(2s) + store_action('""" + self.tmp_file.name + """', foo.test) +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution_ros.process_results()) + with open(self.tmp_file.name) as f: + result = f.read() + self.assertEqual(result, "TEST") + + def test_member_unknown(self): + scenario_content = """ +import osc.ros + +action store_action: + file_path: string + value: string + +actor test_actor: + var test: string = "one" + +scenario test_scenario: + foo: test_actor + + do parallel: + serial: + wait elapsed(1s) + topic_publish("/bla", "std_msgs.msg.String", '{\\\"data\\\": \\\"TEST\\\"}') + serial: + topic_monitor("/bla", "std_msgs.msg.String", foo.test, member_name: "UNKNOWN") + wait elapsed(2s) + store_action('""" + self.tmp_file.name + """', foo.test) +""" + self.execute(scenario_content) + self.assertFalse(self.scenario_execution_ros.process_results()) + + def test_member_relation_expr_success(self): + scenario_content = """ +import osc.ros +import osc.helpers + +struct current_state: + var message_count: int = 1 + +scenario test_scenario: + timeout(10s) + current: current_state + do serial: + parallel: + serial: + repeat() + wait elapsed(1s) + topic_publish("/bla", "std_msgs.msg.Int64", '{\\\"data\\\": 2}') + topic_monitor("/bla", "std_msgs.msg.Int64", member_name: "data", target_variable: current.message_count) + + serial: + wait current.message_count == 2 + emit end +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution_ros.process_results()) diff --git a/test/scenario_execution_test/scenario_execution_test/actions/set_blackboard_var.py b/test/scenario_execution_test/scenario_execution_test/actions/set_blackboard_var.py new file mode 100644 index 00000000..e8b003a6 --- /dev/null +++ b/test/scenario_execution_test/scenario_execution_test/actions/set_blackboard_var.py @@ -0,0 +1,31 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import py_trees +from py_trees.common import Status +from scenario_execution.actions.base_action import BaseAction + + +class SetBlackboardVariable(BaseAction): + + def execute(self, variable_name: str, variable_value): + self.variable_name = variable_name + self.variable_value = variable_value + self.get_blackboard_client().register_key(self.variable_name, access=py_trees.common.Access.WRITE) + + def update(self) -> py_trees.common.Status: + self.get_blackboard_client().set(self.variable_name, self.variable_value) + return Status.SUCCESS diff --git a/test/scenario_execution_test/setup.py b/test/scenario_execution_test/setup.py index 9acd74d4..24524562 100644 --- a/test/scenario_execution_test/setup.py +++ b/test/scenario_execution_test/setup.py @@ -42,5 +42,6 @@ 'scenario_execution.actions': [ 'test_actor.set_value = scenario_execution_test.actions.actor_set_value:ActorSetValue', 'store_action = scenario_execution_test.actions.store_action:StoreAction', + 'set_blackboard_var = scenario_execution_test.actions.set_blackboard_var:SetBlackboardVariable', ]} ) diff --git a/test/scenario_execution_test/test/test_expression_with_var.py b/test/scenario_execution_test/test/test_expression_with_var.py new file mode 100644 index 00000000..4e358aba --- /dev/null +++ b/test/scenario_execution_test/test/test_expression_with_var.py @@ -0,0 +1,98 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import unittest +import tempfile +import py_trees +from scenario_execution import ScenarioExecution +from scenario_execution.model.osc2_parser import OpenScenario2Parser +from scenario_execution.model.model_to_py_tree import create_py_tree +from scenario_execution.model.model_blackboard import create_py_tree_blackboard +from scenario_execution.utils.logging import Logger + +from antlr4.InputStream import InputStream + + +class TestCheckData(unittest.TestCase): + # pylint: disable=missing-function-docstring,missing-class-docstring + + def setUp(self) -> None: + self.parser = OpenScenario2Parser(Logger('test', False)) + self.scenario_execution = ScenarioExecution(debug=False, log_model=False, live_tree=False, + scenario_file="test.osc", output_dir=None) + self.tree = py_trees.composites.Sequence(name="", memory=True) + self.tmp_file = tempfile.NamedTemporaryFile() + + def execute(self, scenario_content): + parsed_tree = self.parser.parse_input_stream(InputStream(scenario_content)) + model = self.parser.create_internal_model(parsed_tree, self.tree, "test.osc", False) + create_py_tree_blackboard(model, self.tree, self.parser.logger, False) + self.tree = create_py_tree(model, self.tree, self.parser.logger, False) + self.scenario_execution.tree = self.tree + self.scenario_execution.run() + + def test_success(self): + scenario_content = """ +import osc.helpers + +struct current_state: + var val: int = 1 + +action set_blackboard_var: + variable_name: string + variable_value: string + +scenario test_scenario: + timeout(5s) + current: current_state + do parallel: + serial: + wait elapsed(0.2s) + set_blackboard_var("current/val", 2) + wait elapsed(10s) + serial: + wait current.val * 2 + 4 - 4 / 2 == 6 + emit end +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution.process_results()) + + def test_success_not(self): + scenario_content = """ +import osc.helpers + +struct current_state: + var val: bool = false + var val2: bool = false + +action set_blackboard_var: + variable_name: string + variable_value: string + +scenario test_scenario: + timeout(5s) + current: current_state + do parallel: + serial: + wait elapsed(0.2s) + set_blackboard_var("current/val", true) + wait elapsed(0.2s) + serial: + wait current.val and not current.val2 + emit end +""" + self.execute(scenario_content) + self.assertTrue(self.scenario_execution.process_results())