Skip to content

Commit

Permalink
feat: support streaming_callback as run param for HF Chat generators
Browse files · Browse the repository at this point in the history
  • Loading branch information
tstadel committed Jan 22, 2025
1 parent f96839e commit 91b752b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
18 changes: 14 additions & 4 deletions haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def run(
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Invoke the text generation inference based on the provided messages and generation parameters.
Expand All @@ -231,6 +232,9 @@ def run(
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:param streaming_callback:
An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
parameter set during component initialization.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage objects.
"""
Expand All @@ -245,16 +249,22 @@ def run(
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

if self.streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs)
streaming_callback = streaming_callback or self.streaming_callback
if streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

hf_tools = None
if tools:
hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
def _run_streaming(
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
streaming_callback: Callable[[StreamingChunk], None],
):
api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
messages, stream=True, **generation_kwargs
)
Expand Down Expand Up @@ -282,7 +292,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
first_chunk_time = datetime.now().isoformat()

stream_chunk = StreamingChunk(text, meta)
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)
streaming_callback(stream_chunk)

meta.update(
{
Expand Down
13 changes: 10 additions & 3 deletions haystack/components/generators/chat/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,18 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
return default_from_dict(cls, data)

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Invoke text generation inference based on the provided messages and generation parameters.
:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:returns:
A list containing the generated responses as ChatMessage instances.
"""
Expand All @@ -259,7 +265,8 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
if stop_words_criteria:
generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

if self.streaming_callback:
streaming_callback = streaming_callback or self.streaming_callback
if streaming_callback:
num_responses = generation_kwargs.get("num_return_sequences", 1)
if num_responses > 1:
msg = (
Expand All @@ -270,7 +277,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
logger.warning(msg, num_responses=num_responses)
generation_kwargs["num_return_sequences"] = 1
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)

hf_messages = [convert_message_to_hf_format(message) for message in messages]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Added support for passing `streaming_callback` as a parameter to the `run` method of the Hugging Face chat generators; when provided, it overrides the `streaming_callback` set during component initialization.

0 comments on commit 91b752b

Please sign in to comment.