From dbb8d57670dcc397bcb087bd3cc023e8e371d6ad Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 27 Nov 2024 02:58:54 -0800 Subject: [PATCH 1/2] Propagate team cancellation token in agentchat --- .../teams/_group_chat/_base_group_chat.py | 7 +- .../_group_chat/_base_group_chat_manager.py | 28 +++++-- .../_group_chat/_chat_agent_container.py | 1 + .../_magentic_one/_magentic_one_group_chat.py | 1 - .../_magentic_one_orchestrator.py | 45 ++++++----- .../tests/test_group_chat.py | 38 ++++++++- .../tests/test_magentic_one_group_chat.py | 80 +++++++++++++++++++ 7 files changed, 171 insertions(+), 29 deletions(-) create mode 100644 python/packages/autogen-agentchat/tests/test_magentic_one_group_chat.py diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index b038d8573059..53867c39710b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -288,12 +288,17 @@ async def stop_runtime() -> None: await self._runtime.send_message( GroupChatStart(message=first_chat_message), recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), + cancellation_token=cancellation_token, ) # Collect the output messages in order. output_messages: List[AgentMessage] = [] # Yield the messsages until the queue is empty. while True: - message = await self._output_message_queue.get() + message_future = asyncio.ensure_future(self._output_message_queue.get()) + if cancellation_token is not None: + cancellation_token.link_future(message_future) + # Wait for the next message, this will raise an exception if the task is cancelled. + message = await message_future if message is None: break yield message diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index d2a2b917690b..201db26bfbac 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from typing import Any, List @@ -78,7 +79,9 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) # Relay the start message to the participants. - await self.publish_message(message, topic_id=DefaultTopicId(type=self._group_topic_type)) + await self.publish_message( + message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token + ) # Append the user message to the message thread. self._message_thread.append(message.message) @@ -95,8 +98,16 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self._termination_condition.reset() return - speaker_topic_type = await self.select_speaker(self._message_thread) - await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=speaker_topic_type)) + # Select a speaker to start the conversation. + speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) + # Link the select speaker future to the cancellation token. + ctx.cancellation_token.link_future(speaker_topic_type_future) + speaker_topic_type = await speaker_topic_type_future + await self.publish_message( + GroupChatRequestPublish(), + topic_id=DefaultTopicId(type=speaker_topic_type), + cancellation_token=ctx.cancellation_token, + ) @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: @@ -140,8 +151,15 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess return # Select a speaker to continue the conversation. - speaker_topic_type = await self.select_speaker(self._message_thread) - await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=speaker_topic_type)) + speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) + # Link the select speaker future to the cancellation token. + ctx.cancellation_token.link_future(speaker_topic_type_future) + speaker_topic_type = await speaker_topic_type_future + await self.publish_message( + GroupChatRequestPublish(), + topic_id=DefaultTopicId(type=speaker_topic_type), + cancellation_token=ctx.cancellation_token, + ) @rpc async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index 315708032865..17c9830086bd 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -71,6 +71,7 @@ async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageCon await self.publish_message( GroupChatAgentResponse(agent_response=response), topic_id=DefaultTopicId(type=self._parent_topic_type), + cancellation_token=ctx.cancellation_token, ) async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py index cd67ced11e55..d199cbfd7125 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_group_chat.py @@ -47,7 +47,6 @@ def _create_group_chat_manager_factory( return lambda: MagenticOneOrchestrator( group_topic_type, output_topic_type, - self._team_id, participant_topic_types, participant_descriptions, max_turns, diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index f69630162340..e1f80a098055 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -1,7 +1,7 @@ import json from typing import Any, List -from autogen_core.base import MessageContext, AgentId +from autogen_core.base import AgentId, CancellationToken, MessageContext from autogen_core.components import DefaultTopicId, Image, event, rpc from autogen_core.components.models import ( AssistantMessage, @@ -42,7 +42,6 @@ def __init__( self, group_topic_type: str, output_topic_type: str, - team_id: str, participant_topic_types: List[str], participant_descriptions: List[str], max_turns: int | None, @@ -52,7 +51,6 @@ def __init__( super().__init__(description="Group chat manager") self._group_topic_type = group_topic_type self._output_topic_type = output_topic_type - self._team_id = team_id if len(participant_topic_types) != len(participant_descriptions): raise ValueError("The number of participant topic types, agent types, and descriptions must be the same.") if len(set(participant_topic_types)) != len(participant_topic_types): @@ -122,7 +120,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No planning_conversation.append( UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name) ) - response = await self._model_client.create(planning_conversation) + response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token) assert isinstance(response.content, str) self._facts = response.content @@ -133,19 +131,19 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No planning_conversation.append( UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name) ) - response = await self._model_client.create(planning_conversation) + response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token) assert isinstance(response.content, str) self._plan = response.content # Kick things off self._n_stalls = 0 - await self._reenter_inner_loop() + await self._reenter_inner_loop(ctx.cancellation_token) @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: self._message_thread.append(message.agent_response.chat_message) - await self._orchestrate_step() + await self._orchestrate_step(ctx.cancellation_token) @rpc async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: @@ -164,12 +162,13 @@ async def reset(self) -> None: async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: raise ValueError(f"Unhandled message in group chat manager: {type(message)}") - async def _reenter_inner_loop(self) -> None: + async def _reenter_inner_loop(self, cancellation_token: CancellationToken) -> None: # Reset the agents for participant_topic_type in self._participant_topic_types: await self._runtime.send_message( GroupChatReset(), - recipient=AgentId(type=participant_topic_type, key=self._team_id), + recipient=AgentId(type=participant_topic_type, key=self.id.key), + cancellation_token=cancellation_token, ) # Reset the group chat manager await self.reset() @@ -197,12 +196,12 @@ async def _reenter_inner_loop(self) -> None: ) # Restart the inner loop - await self._orchestrate_step() + await self._orchestrate_step(cancellation_token=cancellation_token) - async def _orchestrate_step(self) -> None: + async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None: # Check if we reached the maximum number of rounds if self._max_turns is not None and self._n_rounds > self._max_turns: - await self._prepare_final_answer("Max rounds reached.") + await self._prepare_final_answer("Max rounds reached.", cancellation_token) return self._n_rounds += 1 @@ -221,7 +220,7 @@ async def _orchestrate_step(self) -> None: # Check for task completion if progress_ledger["is_request_satisfied"]["answer"]: - await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"]) + await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"], cancellation_token) return # Check for stalling @@ -234,8 +233,8 @@ async def _orchestrate_step(self) -> None: # Too much stalling if self._n_stalls >= self._max_stalls: - await self._update_task_ledger() - await self._reenter_inner_loop() + await self._update_task_ledger(cancellation_token) + await self._reenter_inner_loop(cancellation_token) return # Broadcst the next step @@ -252,20 +251,23 @@ async def _orchestrate_step(self) -> None: await self.publish_message( # Broadcast GroupChatAgentResponse(agent_response=Response(chat_message=message)), topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=cancellation_token, ) # Request that the step be completed next_speaker = progress_ledger["next_speaker"]["answer"] - await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker)) + await self.publish_message( + GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker), cancellation_token=cancellation_token + ) - async def _update_task_ledger(self) -> None: + async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None: context = self._thread_to_context() # Update the facts update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts) context.append(UserMessage(content=update_facts_prompt, source=self._name)) - response = await self._model_client.create(context) + response = await self._model_client.create(context, cancellation_token=cancellation_token) assert isinstance(response.content, str) self._facts = response.content @@ -275,19 +277,19 @@ async def _update_task_ledger(self) -> None: update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description) context.append(UserMessage(content=update_plan_prompt, source=self._name)) - response = await self._model_client.create(context) + response = await self._model_client.create(context, cancellation_token=cancellation_token) assert isinstance(response.content, str) self._plan = response.content - async def _prepare_final_answer(self, reason: str) -> None: + async def _prepare_final_answer(self, reason: str, cancellation_token: CancellationToken) -> None: context = self._thread_to_context() # Get the final answer final_answer_prompt = self._get_final_answer_prompt(self._task) context.append(UserMessage(content=final_answer_prompt, source=self._name)) - response = await self._model_client.create(context) + response = await self._model_client.create(context, cancellation_token=cancellation_token) assert isinstance(response.content, str) message = TextMessage(content=response.content, source=self._name) @@ -303,6 +305,7 @@ async def _prepare_final_answer(self, reason: str) -> None: await self.publish_message( GroupChatAgentResponse(agent_response=Response(chat_message=message)), topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=cancellation_token, ) # Signal termination diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 7df27abcbcd5..00ac5ee90fec 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -69,18 +69,25 @@ class _EchoAgent(BaseChatAgent): def __init__(self, name: str, description: str) -> None: super().__init__(name, description) self._last_message: str | None = None + self._total_messages = 0 @property def produced_message_types(self) -> List[type[ChatMessage]]: return [TextMessage] + @property + def total_messages(self) -> int: + return self._total_messages + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: if len(messages) > 0: assert isinstance(messages[0], TextMessage) self._last_message = messages[0].content + self._total_messages += 1 return Response(chat_message=TextMessage(content=messages[0].content, source=self.name)) else: assert self._last_message is not None + self._total_messages += 1 return Response(chat_message=TextMessage(content=self._last_message, source=self.name)) async def on_reset(self, cancellation_token: CancellationToken) -> None: @@ -358,7 +365,7 @@ async def test_round_robin_group_chat_with_resume_and_reset() -> None: @pytest.mark.asyncio -async def test_round_group_chat_max_turn() -> None: +async def test_round_robin_group_chat_max_turn() -> None: agent_1 = _EchoAgent("agent_1", description="echo agent 1") agent_2 = _EchoAgent("agent_2", description="echo agent 2") agent_3 = _EchoAgent("agent_3", description="echo agent 3") @@ -391,6 +398,35 @@ async def test_round_group_chat_max_turn() -> None: assert result.stop_reason is not None +@pytest.mark.asyncio +async def test_round_robin_group_chat_cancellation() -> None: + agent_1 = _EchoAgent("agent_1", description="echo agent 1") + agent_2 = _EchoAgent("agent_2", description="echo agent 2") + agent_3 = _EchoAgent("agent_3", description="echo agent 3") + agent_4 = _EchoAgent("agent_4", description="echo agent 4") + # Set max_turns to a large number to avoid stopping due to max_turns before cancellation. + team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], max_turns=1000) + cancellation_token = CancellationToken() + run_task = asyncio.create_task( + team.run( + task="Write a program that prints 'Hello, world!'", + cancellation_token=cancellation_token, + ) + ) + await asyncio.sleep(0.1) + # Cancel the task. + cancellation_token.cancel() + with pytest.raises(asyncio.CancelledError): + await run_task + + # Total messages produced so far. + total_messages = agent_1.total_messages + agent_2.total_messages + agent_3.total_messages + agent_4.total_messages + + # Still can run again and finish the task. + result = await team.run() + assert len(result.messages) + total_messages == 1000 + + @pytest.mark.asyncio async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: model = "gpt-4o-2024-05-13" diff --git a/python/packages/autogen-agentchat/tests/test_magentic_one_group_chat.py b/python/packages/autogen-agentchat/tests/test_magentic_one_group_chat.py new file mode 100644 index 000000000000..2a8931a93d66 --- /dev/null +++ b/python/packages/autogen-agentchat/tests/test_magentic_one_group_chat.py @@ -0,0 +1,80 @@ +import asyncio +import json +import logging +from typing import List, Sequence + +import pytest +from autogen_agentchat import EVENT_LOGGER_NAME +from autogen_agentchat.agents import ( + BaseChatAgent, +) +from autogen_agentchat.base import Response +from autogen_agentchat.logging import FileLogHandler +from autogen_agentchat.messages import ( + ChatMessage, + TextMessage, +) +from autogen_agentchat.teams import ( + MagenticOneGroupChat, +) +from autogen_core.base import CancellationToken +from autogen_ext.models import ReplayChatCompletionClient + +logger = logging.getLogger(EVENT_LOGGER_NAME) +logger.setLevel(logging.DEBUG) +logger.addHandler(FileLogHandler("test_magentic_one_group_chat.log")) + + +class _EchoAgent(BaseChatAgent): + def __init__(self, name: str, description: str) -> None: + super().__init__(name, description) + self._last_message: str | None = None + self._total_messages = 0 + + @property + def produced_message_types(self) -> List[type[ChatMessage]]: + return [TextMessage] + + @property + def total_messages(self) -> int: + return self._total_messages + + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + if len(messages) > 0: + assert isinstance(messages[0], TextMessage) + self._last_message = messages[0].content + self._total_messages += 1 + return Response(chat_message=TextMessage(content=messages[0].content, source=self.name)) + else: + assert self._last_message is not None + self._total_messages += 1 + return Response(chat_message=TextMessage(content=self._last_message, source=self.name)) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + self._last_message = None + + +@pytest.mark.asyncio +async def test_magentic_one_group_chat_cancellation() -> None: + agent_1 = _EchoAgent("agent_1", description="echo agent 1") + agent_2 = _EchoAgent("agent_2", description="echo agent 2") + agent_3 = _EchoAgent("agent_3", description="echo agent 3") + agent_4 = _EchoAgent("agent_4", description="echo agent 4") + + model_client = ReplayChatCompletionClient( + chat_completions=["test", "test", json.dumps({"is_request_satisfied": {"answer": True, "reason": "test"}})], + ) + + # Set max_turns to a large number to avoid stopping due to max_turns before cancellation. + team = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client) + cancellation_token = CancellationToken() + run_task = asyncio.create_task( + team.run( + task="Write a program that prints 'Hello, world!'", + cancellation_token=cancellation_token, + ) + ) + # Cancel the task. + cancellation_token.cancel() + with pytest.raises(asyncio.CancelledError): + await run_task From 40b962f90b7fcd6022e04cda708f8fd2fd7aa62e Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 27 Nov 2024 03:20:11 -0800 Subject: [PATCH 2/2] Docs --- .../teams/_group_chat/_base_group_chat.py | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 53867c39710b..0990ddc6652b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -170,6 +170,13 @@ async def run( :meth:`run_stream` to run the team and then returns the final result. Once the team is stopped, the termination condition is reset. + Args: + task (str | ChatMessage | None): The task to run the team with. + cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. + Setting the cancellation token potentially put the team in an inconsistent state, + and it may not reset the termination condition. + To gracefully stop the team, use :class:`~autogen_agentchat.task.ExternalTermination` instead. + Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team: @@ -198,6 +205,47 @@ async def main() -> None: print(result) + asyncio.run(main()) + + + Example using the :class:`~autogen_core.base.CancellationToken` to cancel the task: + + .. code-block:: python + + import asyncio + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.task import MaxMessageTermination + from autogen_agentchat.teams import RoundRobinGroupChat + from autogen_core.base import CancellationToken + from autogen_ext.models import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + termination = MaxMessageTermination(3) + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + + cancellation_token = CancellationToken() + + # Create a task to run the team in the background. + run_task = asyncio.create_task( + team.run( + task="Count from 1 to 10, respond one at a time.", + cancellation_token=cancellation_token, + ) + ) + + # Wait for 1 second and then cancel the task. + await asyncio.sleep(1) + cancellation_token.cancel() + + # This will raise a cancellation error. + await run_task + + asyncio.run(main()) """ result: TaskResult | None = None @@ -221,6 +269,13 @@ async def run_stream( of the type :class:`TaskResult` as the last item in the stream. Once the team is stopped, the termination condition is reset. + Args: + task (str | ChatMessage | None): The task to run the team with. + cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. + Setting the cancellation token potentially put the team in an inconsistent state, + and it may not reset the termination condition. + To gracefully stop the team, use :class:`~autogen_agentchat.task.ExternalTermination` instead. + Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team: .. code-block:: python @@ -251,7 +306,52 @@ async def main() -> None: asyncio.run(main()) + + + Example using the :class:`~autogen_core.base.CancellationToken` to cancel the task: + + .. code-block:: python + + import asyncio + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.task import MaxMessageTermination, Console + from autogen_agentchat.teams import RoundRobinGroupChat + from autogen_core.base import CancellationToken + from autogen_ext.models import OpenAIChatCompletionClient + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + termination = MaxMessageTermination(3) + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + + cancellation_token = CancellationToken() + + # Create a task to run the team in the background. + run_task = asyncio.create_task( + Console( + team.run_stream( + task="Count from 1 to 10, respond one at a time.", + cancellation_token=cancellation_token, + ) + ) + ) + + # Wait for 1 second and then cancel the task. + await asyncio.sleep(1) + cancellation_token.cancel() + + # This will raise a cancellation error. + await run_task + + + asyncio.run(main()) + """ + # Create the first chat message if the task is a string or a chat message. first_chat_message: ChatMessage | None = None if task is None: