From 6b2582533a86f8fd3f43813868f900151f81c753 Mon Sep 17 00:00:00 2001 From: mathislucka Date: Wed, 15 Jan 2025 15:25:21 +0100 Subject: [PATCH] still wait for optional inputs on greedy variadic sockets - mirrors previous behavior --- haystack/core/pipeline/component_checks.py | 47 +++++++++++-------- haystack/core/pipeline/pipeline.py | 3 +- test/core/pipeline/test_component_checks.py | 50 ++++++++++++++------- test/core/pipeline/test_pipeline.py | 16 +++++++ 4 files changed, 81 insertions(+), 35 deletions(-) diff --git a/haystack/core/pipeline/component_checks.py b/haystack/core/pipeline/component_checks.py index 30aaa153cb..ee6bf15deb 100644 --- a/haystack/core/pipeline/component_checks.py +++ b/haystack/core/pipeline/component_checks.py @@ -19,7 +19,7 @@ def can_component_run(component: Dict, inputs: Dict) -> bool: :param component: Component metadata and the component instance. :param inputs: Inputs for the component. """ - received_all_mandatory_inputs = are_all_mandatory_sockets_ready(component, inputs) + received_all_mandatory_inputs = are_all_sockets_ready(component, inputs, only_check_mandatory=True) received_trigger = has_any_trigger(component, inputs) return received_all_mandatory_inputs and received_trigger @@ -49,27 +49,38 @@ def has_any_trigger(component: Dict, inputs: Dict) -> bool: return trigger_from_predecessor or trigger_from_user or trigger_without_inputs -def are_all_mandatory_sockets_ready(component: Dict, inputs: Dict) -> bool: +def are_all_sockets_ready(component: Dict, inputs: Dict, only_check_mandatory: bool = False) -> bool: """ - Checks if all mandatory sockets of a component have enough inputs for the component to execute. + Checks if all sockets of a component have enough inputs for the component to execute. :param component: Component metadata and the component instance. :param inputs: Inputs for the component. - """ - filled_mandatory_sockets = set() - expected_mandatory_sockets = set() - for socket_name, socket in component["input_sockets"].items(): - if socket.is_mandatory: - socket_inputs = inputs.get(socket_name, []) - expected_mandatory_sockets.add(socket_name) - if ( - is_socket_lazy_variadic(socket) - and any_socket_input_received(socket_inputs) - or has_socket_received_all_inputs(socket, socket_inputs) - ): - filled_mandatory_sockets.add(socket_name) - - return filled_mandatory_sockets == expected_mandatory_sockets + :param only_check_mandatory: If only mandatory sockets should be checked. + """ + filled_sockets = set() + expected_sockets = set() + if only_check_mandatory: + sockets_to_check = { + socket_name: socket for socket_name, socket in component["input_sockets"].items() if socket.is_mandatory + } + else: + sockets_to_check = { + socket_name: socket + for socket_name, socket in component["input_sockets"].items() + if socket.is_mandatory or len(socket.senders) + } + + for socket_name, socket in sockets_to_check.items(): + socket_inputs = inputs.get(socket_name, []) + expected_sockets.add(socket_name) + if ( + is_socket_lazy_variadic(socket) + and any_socket_input_received(socket_inputs) + or has_socket_received_all_inputs(socket, socket_inputs) + ): + filled_sockets.add(socket_name) + + return filled_sockets == expected_sockets def any_predecessors_provided_input(component: Dict, inputs: Dict) -> bool: diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 2fdb0bc7e3..fc44f40451 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -13,6 +13,7 @@ from haystack.core.pipeline.component_checks import ( _NO_OUTPUT_PRODUCED, all_predecessors_executed, + are_all_sockets_ready, are_all_lazy_variadic_sockets_resolved, can_component_run, is_any_greedy_socket_ready, @@ -205,7 +206,7 @@ def _calculate_priority(component: Dict, inputs: Dict) -> ComponentPriority: """ if not can_component_run(component, inputs): return ComponentPriority.BLOCKED - elif is_any_greedy_socket_ready(component, inputs): + elif is_any_greedy_socket_ready(component, inputs) and are_all_sockets_ready(component, inputs): return ComponentPriority.HIGHEST elif all_predecessors_executed(component, inputs): return ComponentPriority.READY diff --git a/test/core/pipeline/test_component_checks.py b/test/core/pipeline/test_component_checks.py index c6ce9f2ba9..faa74b1eb4 100644 --- a/test/core/pipeline/test_component_checks.py +++ b/test/core/pipeline/test_component_checks.py @@ -13,7 +13,7 @@ def basic_component(): "instance": "mock_instance", "visits": 0, "input_sockets": { - "mandatory_input": InputSocket("mandatory_input", int), + "mandatory_input": InputSocket("mandatory_input", int, senders=["previous_component"]), "optional_input": InputSocket("optional_input", str, default_value="default"), }, "output_sockets": {"output": OutputSocket("output", int)}, @@ -27,8 +27,8 @@ def variadic_component(): "instance": "mock_instance", "visits": 0, "input_sockets": { - "variadic_input": InputSocket("variadic_input", Variadic[int]), - "normal_input": InputSocket("normal_input", str), + "variadic_input": InputSocket("variadic_input", Variadic[int], senders=["previous_component"]), + "normal_input": InputSocket("normal_input", str, senders=["another_component"]), }, "output_sockets": {"output": OutputSocket("output", int)}, } @@ -41,7 +41,9 @@ def greedy_variadic_component(): "instance": "mock_instance", "visits": 0, "input_sockets": { - "greedy_input": InputSocket("greedy_input", GreedyVariadic[int]), + "greedy_input": InputSocket( + "greedy_input", GreedyVariadic[int], senders=["previous_component", "other_component"] + ), "normal_input": InputSocket("normal_input", str), }, "output_sockets": {"output": OutputSocket("output", int)}, @@ -194,12 +196,12 @@ class TestAllMandatorySocketsReady: def test_all_mandatory_sockets_filled(self, basic_component): """Checks that all mandatory sockets are ready when they have valid input.""" inputs = {"mandatory_input": [{"sender": "previous_component", "value": 42}]} - assert are_all_mandatory_sockets_ready(basic_component, inputs) is True + assert are_all_sockets_ready(basic_component, inputs) is True def test_missing_mandatory_socket(self, basic_component): """Ensures that if a mandatory socket is missing, the component is not ready.""" inputs = {"optional_input": [{"sender": "previous_component", "value": "test"}]} - assert are_all_mandatory_sockets_ready(basic_component, inputs) is False + assert are_all_sockets_ready(basic_component, inputs) is False def test_variadic_socket_with_input(self, variadic_component): """Verifies that a variadic socket is considered filled if it has at least one input.""" @@ -207,25 +209,41 @@ def test_variadic_socket_with_input(self, variadic_component): "variadic_input": [{"sender": "previous_component", "value": 42}], "normal_input": [{"sender": "previous_component", "value": "test"}], } - assert are_all_mandatory_sockets_ready(variadic_component, inputs) is True + assert are_all_sockets_ready(variadic_component, inputs) is True - def test_greedy_variadic_socket_with_partial_input(self, greedy_variadic_component): + def test_greedy_variadic_socket(self, greedy_variadic_component): """Greedy variadic sockets are ready with at least one valid input.""" inputs = { "greedy_input": [{"sender": "previous_component", "value": 42}], "normal_input": [{"sender": "previous_component", "value": "test"}], } - assert are_all_mandatory_sockets_ready(greedy_variadic_component, inputs) is True + assert are_all_sockets_ready(greedy_variadic_component, inputs) is True + + def test_greedy_variadic_socket_and_missing_mandatory(self, greedy_variadic_component): + """All mandatory sockets need to be filled even with GreedyVariadic sockets.""" + inputs = {"greedy_input": [{"sender": "previous_component", "value": 42}]} + assert are_all_sockets_ready(greedy_variadic_component, inputs, only_check_mandatory=True) is False def test_variadic_socket_no_input(self, variadic_component): """A variadic socket is not filled if it has zero valid inputs.""" inputs = {"normal_input": [{"sender": "previous_component", "value": "test"}]} - assert are_all_mandatory_sockets_ready(variadic_component, inputs) is False + assert are_all_sockets_ready(variadic_component, inputs) is False + + def test_mandatory_and_optional_sockets(self): + input_sockets = { + "mandatory": InputSocket("mandatory", str, senders=["previous_component"]), + "optional": InputSocket("optional", str, senders=["previous_component"], default_value="test"), + } + + component = {"input_sockets": input_sockets} + inputs = {"mandatory": [{"sender": "previous_component", "value": "hello"}]} + assert are_all_sockets_ready(component, inputs) is False + assert are_all_sockets_ready(component, inputs, only_check_mandatory=True) is True def test_empty_inputs(self, basic_component): """Checks that if there are no inputs at all, mandatory sockets are not ready.""" inputs = {} - assert are_all_mandatory_sockets_ready(basic_component, inputs) is False + assert are_all_sockets_ready(basic_component, inputs) is False def test_no_mandatory_sockets(self, basic_component): """Ensures that if there are no mandatory sockets, the component is considered ready.""" @@ -234,24 +252,24 @@ def test_no_mandatory_sockets(self, basic_component): "optional_2": InputSocket("optional_2", str, default_value="default2"), } inputs = {} - assert are_all_mandatory_sockets_ready(basic_component, inputs) is True + assert are_all_sockets_ready(basic_component, inputs) is True def test_multiple_mandatory_sockets(self, basic_component): """Checks readiness when multiple mandatory sockets are defined.""" basic_component["input_sockets"] = { - "mandatory_1": InputSocket("mandatory_1", int), - "mandatory_2": InputSocket("mandatory_2", str), + "mandatory_1": InputSocket("mandatory_1", int, senders=["previous_component"]), + "mandatory_2": InputSocket("mandatory_2", str, senders=["some other component"]), "optional": InputSocket("optional", bool, default_value=False), } inputs = { "mandatory_1": [{"sender": "comp1", "value": 42}], "mandatory_2": [{"sender": "comp2", "value": "test"}], } - assert are_all_mandatory_sockets_ready(basic_component, inputs) is True + assert are_all_sockets_ready(basic_component, inputs) is True # Missing one mandatory input inputs = {"mandatory_1": [{"sender": "comp1", "value": 42}], "optional": [{"sender": "comp3", "value": True}]} - assert are_all_mandatory_sockets_ready(basic_component, inputs) is False + assert are_all_sockets_ready(basic_component, inputs) is False class TestPredecessorInputDetection: diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index d24b20425e..70e183b5b9 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -1210,6 +1210,22 @@ def test__find_receivers_from(self): ComponentPriority.HIGHEST, "Component should have HIGHEST priority when greedy socket has valid input", ), + # Test case 4: DEFER - Greedy socket ready but optional missing + ( + { + "instance": "mock_instance", + "visits": 0, + "input_sockets": { + "greedy_input": InputSocket("greedy_input", GreedyVariadic[int], senders=["component1"]), + "optional_input": InputSocket( + "optional_input", str, senders=["component2"], default_value="test" + ), + }, + }, + {"greedy_input": [{"sender": "component1", "value": 42}]}, + ComponentPriority.DEFER, + "Component should DEFER when greedy socket has valid input but expected optional input is missing", + ), # Test case 4: READY - All predecessors executed ( {