diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index a36b4355ee..a2ca1c0d76 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -8,7 +8,7 @@ from warnings import warn from jinja2 import Environment, TemplateSyntaxError, meta -from jinja2.nativetypes import NativeEnvironment, NativeTemplate +from jinja2.nativetypes import NativeEnvironment, NativeTemplate, Template from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -haystack_dataclass_types = (ByteStream, Document, ChatMessage, Answer, SparseEmbedding, StreamingChunk) +haystack_dataclass_types = (ByteStream, ChatMessage, Document, Answer, SparseEmbedding, StreamingChunk) class NoRouteSelectedException(Exception): @@ -28,7 +28,7 @@ class RouteConditionException(Exception): """Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter.""" -class NativeSandboxedTemplate(NativeTemplate): +class NativeSandboxedTemplate(NativeTemplate, Template): """ A template class that returns native Python objects and also respects the sandbox security checks. """ @@ -42,25 +42,25 @@ class NativeSandboxedEnvironment(SandboxedEnvironment, NativeEnvironment): """ # We tell the environment to use our custom template class by default. - template_class = NativeSandboxedTemplate - def from_string(self, source, template_class=None): + def from_string(self, source): """ Override from_string to ensure the sandbox logic + native logic are used together. """ - if template_class is None: - template_class = self.template_class + template_class = NativeSandboxedTemplate + return SandboxedEnvironment.from_string(self, source, template_class=template_class) - def is_safe_attribute(self, obj): + def is_safe_attribute(self, obj, attr="", value=""): """ Whitelist Haystack dataclasses so the sandbox won't block them. """ - if isinstance(obj, haystack_dataclass_types): - return True + + if not isinstance(obj, haystack_dataclass_types): + return False # Otherwise, fallback to the default sandbox behavior - return super().is_safe_attribute(obj) + return SandboxedEnvironment.is_safe_attribute(self, obj, attr, value) @component @@ -236,14 +236,6 @@ def __init__( # pylint: disable=too-many-positional-arguments self._custom_env = NativeSandboxedEnvironment() self._env.filters.update(self.custom_filters) - # Add custom types to the custom environment - self._custom_env.globals["Document"] = Document - self._custom_env.globals["ChatMessage"] = ChatMessage - self._custom_env.globals["ByteStream"] = ByteStream - self._custom_env.globals["Answer"] = Answer - self._custom_env.globals["SparseEmbedding"] = SparseEmbedding - self._custom_env.globals["StreamingChunk"] = StreamingChunk - self._validate_routes(routes) # Inspect the routes to determine input and output types. input_types: Set[str] = set() # let's just store the name, type will always be Any @@ -359,13 +351,21 @@ def run(self, **kwargs): t_output = self._custom_env.from_string(route["output"]) output = t_output.render(**kwargs) + # Check if output is a list/sequence and validate accordingly + if isinstance(output, (list, tuple)): + if all(self._custom_env.is_safe_attribute(item) for item in output): + pass + elif self._custom_env.is_safe_attribute(output): + pass + # We suppress the exception in case the output is already a string, otherwise # we try to evaluate it and would fail. # This must be done cause the output could be different literal structures. # This doesn't support any user types. - with contextlib.suppress(Exception): - if not self._unsafe and isinstance(output, str): - output = ast.literal_eval(output) + else: + with contextlib.suppress(Exception): + if not self._unsafe and isinstance(output, str): + output = ast.literal_eval(output) except Exception as e: msg = f"Error evaluating condition for route '{route}': {e}" raise RouteConditionException(msg) from e