Skip to content

Commit

Permalink
still wait for optional inputs on greedy variadic sockets
Browse files Browse the repository at this point in the history
- mirrors previous behavior
  • Loading branch information
mathislucka committed Jan 15, 2025
1 parent b2b8adc commit 6b25825
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 35 deletions.
47 changes: 29 additions & 18 deletions haystack/core/pipeline/component_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def can_component_run(component: Dict, inputs: Dict) -> bool:
:param component: Component metadata and the component instance.
:param inputs: Inputs for the component.
"""
received_all_mandatory_inputs = are_all_mandatory_sockets_ready(component, inputs)
received_all_mandatory_inputs = are_all_sockets_ready(component, inputs, only_check_mandatory=True)
received_trigger = has_any_trigger(component, inputs)

return received_all_mandatory_inputs and received_trigger
Expand Down Expand Up @@ -49,27 +49,38 @@ def has_any_trigger(component: Dict, inputs: Dict) -> bool:
return trigger_from_predecessor or trigger_from_user or trigger_without_inputs


def are_all_mandatory_sockets_ready(component: Dict, inputs: Dict) -> bool:
def are_all_sockets_ready(component: Dict, inputs: Dict, only_check_mandatory: bool = False) -> bool:
"""
Checks if all mandatory sockets of a component have enough inputs for the component to execute.
Checks if all sockets of a component have enough inputs for the component to execute.
:param component: Component metadata and the component instance.
:param inputs: Inputs for the component.
"""
filled_mandatory_sockets = set()
expected_mandatory_sockets = set()
for socket_name, socket in component["input_sockets"].items():
if socket.is_mandatory:
socket_inputs = inputs.get(socket_name, [])
expected_mandatory_sockets.add(socket_name)
if (
is_socket_lazy_variadic(socket)
and any_socket_input_received(socket_inputs)
or has_socket_received_all_inputs(socket, socket_inputs)
):
filled_mandatory_sockets.add(socket_name)

return filled_mandatory_sockets == expected_mandatory_sockets
:param only_check_mandatory: If only mandatory sockets should be checked.
"""
filled_sockets = set()
expected_sockets = set()
if only_check_mandatory:
sockets_to_check = {
socket_name: socket for socket_name, socket in component["input_sockets"].items() if socket.is_mandatory
}
else:
sockets_to_check = {
socket_name: socket
for socket_name, socket in component["input_sockets"].items()
if socket.is_mandatory or len(socket.senders)
}

for socket_name, socket in sockets_to_check.items():
socket_inputs = inputs.get(socket_name, [])
expected_sockets.add(socket_name)
if (
is_socket_lazy_variadic(socket)
and any_socket_input_received(socket_inputs)
or has_socket_received_all_inputs(socket, socket_inputs)
):
filled_sockets.add(socket_name)

return filled_sockets == expected_sockets


def any_predecessors_provided_input(component: Dict, inputs: Dict) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from haystack.core.pipeline.component_checks import (
_NO_OUTPUT_PRODUCED,
all_predecessors_executed,
are_all_sockets_ready,
are_all_lazy_variadic_sockets_resolved,
can_component_run,
is_any_greedy_socket_ready,
Expand Down Expand Up @@ -205,7 +206,7 @@ def _calculate_priority(component: Dict, inputs: Dict) -> ComponentPriority:
"""
if not can_component_run(component, inputs):
return ComponentPriority.BLOCKED
elif is_any_greedy_socket_ready(component, inputs):
elif is_any_greedy_socket_ready(component, inputs) and are_all_sockets_ready(component, inputs):
return ComponentPriority.HIGHEST
elif all_predecessors_executed(component, inputs):
return ComponentPriority.READY
Expand Down
50 changes: 34 additions & 16 deletions test/core/pipeline/test_component_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def basic_component():
"instance": "mock_instance",
"visits": 0,
"input_sockets": {
"mandatory_input": InputSocket("mandatory_input", int),
"mandatory_input": InputSocket("mandatory_input", int, senders=["previous_component"]),
"optional_input": InputSocket("optional_input", str, default_value="default"),
},
"output_sockets": {"output": OutputSocket("output", int)},
Expand All @@ -27,8 +27,8 @@ def variadic_component():
"instance": "mock_instance",
"visits": 0,
"input_sockets": {
"variadic_input": InputSocket("variadic_input", Variadic[int]),
"normal_input": InputSocket("normal_input", str),
"variadic_input": InputSocket("variadic_input", Variadic[int], senders=["previous_component"]),
"normal_input": InputSocket("normal_input", str, senders=["another_component"]),
},
"output_sockets": {"output": OutputSocket("output", int)},
}
Expand All @@ -41,7 +41,9 @@ def greedy_variadic_component():
"instance": "mock_instance",
"visits": 0,
"input_sockets": {
"greedy_input": InputSocket("greedy_input", GreedyVariadic[int]),
"greedy_input": InputSocket(
"greedy_input", GreedyVariadic[int], senders=["previous_component", "other_component"]
),
"normal_input": InputSocket("normal_input", str),
},
"output_sockets": {"output": OutputSocket("output", int)},
Expand Down Expand Up @@ -194,38 +196,54 @@ class TestAllMandatorySocketsReady:
def test_all_mandatory_sockets_filled(self, basic_component):
"""Checks that all mandatory sockets are ready when they have valid input."""
inputs = {"mandatory_input": [{"sender": "previous_component", "value": 42}]}
assert are_all_mandatory_sockets_ready(basic_component, inputs) is True
assert are_all_sockets_ready(basic_component, inputs) is True

def test_missing_mandatory_socket(self, basic_component):
"""Ensures that if a mandatory socket is missing, the component is not ready."""
inputs = {"optional_input": [{"sender": "previous_component", "value": "test"}]}
assert are_all_mandatory_sockets_ready(basic_component, inputs) is False
assert are_all_sockets_ready(basic_component, inputs) is False

def test_variadic_socket_with_input(self, variadic_component):
"""Verifies that a variadic socket is considered filled if it has at least one input."""
inputs = {
"variadic_input": [{"sender": "previous_component", "value": 42}],
"normal_input": [{"sender": "previous_component", "value": "test"}],
}
assert are_all_mandatory_sockets_ready(variadic_component, inputs) is True
assert are_all_sockets_ready(variadic_component, inputs) is True

def test_greedy_variadic_socket_with_partial_input(self, greedy_variadic_component):
def test_greedy_variadic_socket(self, greedy_variadic_component):
"""Greedy variadic sockets are ready with at least one valid input."""
inputs = {
"greedy_input": [{"sender": "previous_component", "value": 42}],
"normal_input": [{"sender": "previous_component", "value": "test"}],
}
assert are_all_mandatory_sockets_ready(greedy_variadic_component, inputs) is True
assert are_all_sockets_ready(greedy_variadic_component, inputs) is True

def test_greedy_variadic_socket_and_missing_mandatory(self, greedy_variadic_component):
"""All mandatory sockets need to be filled even with GreedyVariadic sockets."""
inputs = {"greedy_input": [{"sender": "previous_component", "value": 42}]}
assert are_all_sockets_ready(greedy_variadic_component, inputs, only_check_mandatory=True) is False

def test_variadic_socket_no_input(self, variadic_component):
"""A variadic socket is not filled if it has zero valid inputs."""
inputs = {"normal_input": [{"sender": "previous_component", "value": "test"}]}
assert are_all_mandatory_sockets_ready(variadic_component, inputs) is False
assert are_all_sockets_ready(variadic_component, inputs) is False

def test_mandatory_and_optional_sockets(self):
input_sockets = {
"mandatory": InputSocket("mandatory", str, senders=["previous_component"]),
"optional": InputSocket("optional", str, senders=["previous_component"], default_value="test"),
}

component = {"input_sockets": input_sockets}
inputs = {"mandatory": [{"sender": "previous_component", "value": "hello"}]}
assert are_all_sockets_ready(component, inputs) is False
assert are_all_sockets_ready(component, inputs, only_check_mandatory=True) is True

def test_empty_inputs(self, basic_component):
"""Checks that if there are no inputs at all, mandatory sockets are not ready."""
inputs = {}
assert are_all_mandatory_sockets_ready(basic_component, inputs) is False
assert are_all_sockets_ready(basic_component, inputs) is False

def test_no_mandatory_sockets(self, basic_component):
"""Ensures that if there are no mandatory sockets, the component is considered ready."""
Expand All @@ -234,24 +252,24 @@ def test_no_mandatory_sockets(self, basic_component):
"optional_2": InputSocket("optional_2", str, default_value="default2"),
}
inputs = {}
assert are_all_mandatory_sockets_ready(basic_component, inputs) is True
assert are_all_sockets_ready(basic_component, inputs) is True

def test_multiple_mandatory_sockets(self, basic_component):
"""Checks readiness when multiple mandatory sockets are defined."""
basic_component["input_sockets"] = {
"mandatory_1": InputSocket("mandatory_1", int),
"mandatory_2": InputSocket("mandatory_2", str),
"mandatory_1": InputSocket("mandatory_1", int, senders=["previous_component"]),
"mandatory_2": InputSocket("mandatory_2", str, senders=["some other component"]),
"optional": InputSocket("optional", bool, default_value=False),
}
inputs = {
"mandatory_1": [{"sender": "comp1", "value": 42}],
"mandatory_2": [{"sender": "comp2", "value": "test"}],
}
assert are_all_mandatory_sockets_ready(basic_component, inputs) is True
assert are_all_sockets_ready(basic_component, inputs) is True

# Missing one mandatory input
inputs = {"mandatory_1": [{"sender": "comp1", "value": 42}], "optional": [{"sender": "comp3", "value": True}]}
assert are_all_mandatory_sockets_ready(basic_component, inputs) is False
assert are_all_sockets_ready(basic_component, inputs) is False


class TestPredecessorInputDetection:
Expand Down
16 changes: 16 additions & 0 deletions test/core/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,22 @@ def test__find_receivers_from(self):
ComponentPriority.HIGHEST,
"Component should have HIGHEST priority when greedy socket has valid input",
),
# Test case 4: DEFER - Greedy socket ready but optional missing
(
{
"instance": "mock_instance",
"visits": 0,
"input_sockets": {
"greedy_input": InputSocket("greedy_input", GreedyVariadic[int], senders=["component1"]),
"optional_input": InputSocket(
"optional_input", str, senders=["component2"], default_value="test"
),
},
},
{"greedy_input": [{"sender": "component1", "value": 42}]},
ComponentPriority.DEFER,
"Component should DEFER when greedy socket has valid input but expected optional input is missing",
),
# Test case 4: READY - All predecessors executed
(
{
Expand Down

0 comments on commit 6b25825

Please sign in to comment.