diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 1264272fca..9fb5bda1f3 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -220,6 +220,7 @@ def run( messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ Invoke the text generation inference based on the provided messages and generation parameters. @@ -231,6 +232,9 @@ def run( :param tools: A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set during component initialization. + :param streaming_callback: + An optional callable for handling streaming responses. If set, it will override the `streaming_callback` + parameter set during component initialization. :returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage objects. """ @@ -245,8 +249,9 @@ def run( raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") _check_duplicate_tool_names(tools) - if self.streaming_callback: - return self._run_streaming(formatted_messages, generation_kwargs) + streaming_callback = streaming_callback or self.streaming_callback + if streaming_callback: + return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback) hf_tools = None if tools: @@ -254,7 +259,12 @@ def run( return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools) - def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]): + def _run_streaming( + self, + messages: List[Dict[str, str]], + generation_kwargs: Dict[str, Any], + streaming_callback: Callable[[StreamingChunk], None], + ): api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion( messages, stream=True, **generation_kwargs ) @@ -282,7 +292,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict first_chunk_time = datetime.now().isoformat() stream_chunk = StreamingChunk(text, meta) - self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) + streaming_callback(stream_chunk) meta.update( { diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index a79a6dcfa8..d5d05ae487 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -233,12 +233,18 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator": return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Invoke text generation 
inference based on the provided messages and generation parameters. :param messages: A list of ChatMessage objects representing the input messages. :param generation_kwargs: Additional keyword arguments for text generation. + :param streaming_callback: An optional callable for handling streaming responses. :returns: A list containing the generated responses as ChatMessage instances. """ @@ -259,7 +265,8 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, if stop_words_criteria: generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria]) - if self.streaming_callback: + streaming_callback = streaming_callback or self.streaming_callback + if streaming_callback: num_responses = generation_kwargs.get("num_return_sequences", 1) if num_responses > 1: msg = ( @@ -270,7 +277,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, logger.warning(msg, num_responses=num_responses) generation_kwargs["num_return_sequences"] = 1 # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming - generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words) + generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) hf_messages = [convert_message_to_hf_format(message) for message in messages] diff --git a/releasenotes/notes/streaming-callback-run-param-support-for-hf-chat-generators-68aaa7e540ad03ce.yaml b/releasenotes/notes/streaming-callback-run-param-support-for-hf-chat-generators-68aaa7e540ad03ce.yaml new file mode 100644 index 0000000000..8d5e6007f5 --- /dev/null +++ b/releasenotes/notes/streaming-callback-run-param-support-for-hf-chat-generators-68aaa7e540ad03ce.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added support for passing `streaming_callback` as a run-time parameter to the Hugging Face chat generators, overriding the callback set at initialization.