Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate team cancellation token in agentchat #4400

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ async def run(
:meth:`run_stream` to run the team and then returns the final result.
Once the team is stopped, the termination condition is reset.

Args:
task (str | ChatMessage | None): The task to run the team with.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially puts the team in an inconsistent state,
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.task.ExternalTermination` instead.

Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:


Expand Down Expand Up @@ -198,6 +205,47 @@ async def main() -> None:
print(result)


asyncio.run(main())


Example using the :class:`~autogen_core.base.CancellationToken` to cancel the task:

.. code-block:: python

import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core.base import CancellationToken
from autogen_ext.models import OpenAIChatCompletionClient


async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")

agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
termination = MaxMessageTermination(3)
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

cancellation_token = CancellationToken()

# Create a task to run the team in the background.
run_task = asyncio.create_task(
team.run(
task="Count from 1 to 10, respond one at a time.",
cancellation_token=cancellation_token,
)
)

# Wait for 1 second and then cancel the task.
await asyncio.sleep(1)
cancellation_token.cancel()

# This will raise a cancellation error.
await run_task


asyncio.run(main())
"""
result: TaskResult | None = None
Expand All @@ -221,6 +269,13 @@ async def run_stream(
of the type :class:`TaskResult` as the last item in the stream. Once the
team is stopped, the termination condition is reset.

Args:
task (str | ChatMessage | None): The task to run the team with.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially puts the team in an inconsistent state,
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.task.ExternalTermination` instead.

Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:

.. code-block:: python
Expand Down Expand Up @@ -251,7 +306,52 @@ async def main() -> None:


asyncio.run(main())


Example using the :class:`~autogen_core.base.CancellationToken` to cancel the task:

.. code-block:: python

import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.task import MaxMessageTermination, Console
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core.base import CancellationToken
from autogen_ext.models import OpenAIChatCompletionClient


async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")

agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
termination = MaxMessageTermination(3)
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

cancellation_token = CancellationToken()

# Create a task to run the team in the background.
run_task = asyncio.create_task(
Console(
team.run_stream(
task="Count from 1 to 10, respond one at a time.",
cancellation_token=cancellation_token,
)
)
)

# Wait for 1 second and then cancel the task.
await asyncio.sleep(1)
cancellation_token.cancel()

# This will raise a cancellation error.
await run_task


asyncio.run(main())

"""

# Create the first chat message if the task is a string or a chat message.
first_chat_message: ChatMessage | None = None
if task is None:
Expand Down Expand Up @@ -288,12 +388,17 @@ async def stop_runtime() -> None:
await self._runtime.send_message(
GroupChatStart(message=first_chat_message),
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
cancellation_token=cancellation_token,
)
# Collect the output messages in order.
output_messages: List[AgentMessage] = []
# Yield the messages until the queue is empty.
while True:
message = await self._output_message_queue.get()
message_future = asyncio.ensure_future(self._output_message_queue.get())
if cancellation_token is not None:
cancellation_token.link_future(message_future)
# Wait for the next message, this will raise an exception if the task is cancelled.
message = await message_future
if message is None:
break
yield message
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Any, List

Expand Down Expand Up @@ -78,7 +79,9 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))

# Relay the start message to the participants.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._group_topic_type))
await self.publish_message(
message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token
)

# Append the user message to the message thread.
self._message_thread.append(message.message)
Expand All @@ -95,8 +98,16 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
await self._termination_condition.reset()
return

speaker_topic_type = await self.select_speaker(self._message_thread)
await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=speaker_topic_type))
# Select a speaker to start the conversation.
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_topic_type_future)
speaker_topic_type = await speaker_topic_type_future
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)

@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
Expand Down Expand Up @@ -140,8 +151,15 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
return

# Select a speaker to continue the conversation.
speaker_topic_type = await self.select_speaker(self._message_thread)
await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=speaker_topic_type))
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_topic_type_future)
speaker_topic_type = await speaker_topic_type_future
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)

@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageCon
await self.publish_message(
GroupChatAgentResponse(agent_response=response),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)

async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def _create_group_chat_manager_factory(
return lambda: MagenticOneOrchestrator(
group_topic_type,
output_topic_type,
self._team_id,
participant_topic_types,
participant_descriptions,
max_turns,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Any, List

from autogen_core.base import MessageContext, AgentId
from autogen_core.base import AgentId, CancellationToken, MessageContext
from autogen_core.components import DefaultTopicId, Image, event, rpc
from autogen_core.components.models import (
AssistantMessage,
Expand Down Expand Up @@ -42,7 +42,6 @@ def __init__(
self,
group_topic_type: str,
output_topic_type: str,
team_id: str,
participant_topic_types: List[str],
participant_descriptions: List[str],
max_turns: int | None,
Expand All @@ -52,7 +51,6 @@ def __init__(
super().__init__(description="Group chat manager")
self._group_topic_type = group_topic_type
self._output_topic_type = output_topic_type
self._team_id = team_id
if len(participant_topic_types) != len(participant_descriptions):
raise ValueError("The number of participant topic types, agent types, and descriptions must be the same.")
if len(set(participant_topic_types)) != len(participant_topic_types):
Expand Down Expand Up @@ -122,7 +120,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
planning_conversation.append(
UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name)
)
response = await self._model_client.create(planning_conversation)
response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token)

assert isinstance(response.content, str)
self._facts = response.content
Expand All @@ -133,19 +131,19 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
planning_conversation.append(
UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name)
)
response = await self._model_client.create(planning_conversation)
response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token)

assert isinstance(response.content, str)
self._plan = response.content

# Kick things off
self._n_stalls = 0
await self._reenter_inner_loop()
await self._reenter_inner_loop(ctx.cancellation_token)

@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
self._message_thread.append(message.agent_response.chat_message)
await self._orchestrate_step()
await self._orchestrate_step(ctx.cancellation_token)

@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
Expand All @@ -164,12 +162,13 @@ async def reset(self) -> None:
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in group chat manager: {type(message)}")

async def _reenter_inner_loop(self) -> None:
async def _reenter_inner_loop(self, cancellation_token: CancellationToken) -> None:
# Reset the agents
for participant_topic_type in self._participant_topic_types:
await self._runtime.send_message(
GroupChatReset(),
recipient=AgentId(type=participant_topic_type, key=self._team_id),
recipient=AgentId(type=participant_topic_type, key=self.id.key),
cancellation_token=cancellation_token,
)
# Reset the group chat manager
await self.reset()
Expand Down Expand Up @@ -197,12 +196,12 @@ async def _reenter_inner_loop(self) -> None:
)

# Restart the inner loop
await self._orchestrate_step()
await self._orchestrate_step(cancellation_token=cancellation_token)

async def _orchestrate_step(self) -> None:
async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None:
# Check if we reached the maximum number of rounds
if self._max_turns is not None and self._n_rounds > self._max_turns:
await self._prepare_final_answer("Max rounds reached.")
await self._prepare_final_answer("Max rounds reached.", cancellation_token)
return
self._n_rounds += 1

Expand All @@ -221,7 +220,7 @@ async def _orchestrate_step(self) -> None:

# Check for task completion
if progress_ledger["is_request_satisfied"]["answer"]:
await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"])
await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"], cancellation_token)
return

# Check for stalling
Expand All @@ -234,8 +233,8 @@ async def _orchestrate_step(self) -> None:

# Too much stalling
if self._n_stalls >= self._max_stalls:
await self._update_task_ledger()
await self._reenter_inner_loop()
await self._update_task_ledger(cancellation_token)
await self._reenter_inner_loop(cancellation_token)
return

# Broadcast the next step
Expand All @@ -252,20 +251,23 @@ async def _orchestrate_step(self) -> None:
await self.publish_message( # Broadcast
GroupChatAgentResponse(agent_response=Response(chat_message=message)),
topic_id=DefaultTopicId(type=self._group_topic_type),
cancellation_token=cancellation_token,
)

# Request that the step be completed
next_speaker = progress_ledger["next_speaker"]["answer"]
await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker))
await self.publish_message(
GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker), cancellation_token=cancellation_token
)

async def _update_task_ledger(self) -> None:
async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
context = self._thread_to_context()

# Update the facts
update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
context.append(UserMessage(content=update_facts_prompt, source=self._name))

response = await self._model_client.create(context)
response = await self._model_client.create(context, cancellation_token=cancellation_token)

assert isinstance(response.content, str)
self._facts = response.content
Expand All @@ -275,19 +277,19 @@ async def _update_task_ledger(self) -> None:
update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
context.append(UserMessage(content=update_plan_prompt, source=self._name))

response = await self._model_client.create(context)
response = await self._model_client.create(context, cancellation_token=cancellation_token)

assert isinstance(response.content, str)
self._plan = response.content

async def _prepare_final_answer(self, reason: str) -> None:
async def _prepare_final_answer(self, reason: str, cancellation_token: CancellationToken) -> None:
context = self._thread_to_context()

# Get the final answer
final_answer_prompt = self._get_final_answer_prompt(self._task)
context.append(UserMessage(content=final_answer_prompt, source=self._name))

response = await self._model_client.create(context)
response = await self._model_client.create(context, cancellation_token=cancellation_token)
assert isinstance(response.content, str)
message = TextMessage(content=response.content, source=self._name)

Expand All @@ -303,6 +305,7 @@ async def _prepare_final_answer(self, reason: str) -> None:
await self.publish_message(
GroupChatAgentResponse(agent_response=Response(chat_message=message)),
topic_id=DefaultTopicId(type=self._group_topic_type),
cancellation_token=cancellation_token,
)

# Signal termination
Expand Down
Loading
Loading