Skip to content

Commit

Permalink
Merge branch 'main' into add-dist-group-chat-dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
MohMaz authored Jan 24, 2025
2 parents ee24828 + 979d8ab commit 2814243
Show file tree
Hide file tree
Showing 13 changed files with 509 additions and 180 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any, AsyncGenerator, List, Mapping, Sequence

from autogen_core import CancellationToken
from autogen_core import CancellationToken, Component, ComponentModel
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from pydantic import BaseModel
from typing_extensions import Self

from autogen_agentchat.base import Response
from autogen_agentchat.state import SocietyOfMindAgentState
Expand All @@ -16,7 +18,18 @@
from ._base_chat_agent import BaseChatAgent


class SocietyOfMindAgent(BaseChatAgent):
class SocietyOfMindAgentConfig(BaseModel):
"""The declarative configuration for a SocietyOfMindAgent."""

name: str
team: ComponentModel
model_client: ComponentModel
description: str
instruction: str
response_prompt: str


class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
"""An agent that uses an inner team of agents to generate responses.
Each time the agent's :meth:`on_messages` or :meth:`on_messages_stream`
Expand Down Expand Up @@ -74,6 +87,9 @@ async def main() -> None:
asyncio.run(main())
"""

component_config_schema = SocietyOfMindAgentConfig
component_provider_override = "autogen_agentchat.agents.SocietyOfMindAgent"

DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
"""str: The default instruction to use when generating a response using the
inner team's messages. The instruction will be prepended to the inner team's
Expand Down Expand Up @@ -173,3 +189,26 @@ async def save_state(self) -> Mapping[str, Any]:
async def load_state(self, state: Mapping[str, Any]) -> None:
society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
await self._team.load_state(society_of_mind_state.inner_team_state)

def _to_config(self) -> SocietyOfMindAgentConfig:
return SocietyOfMindAgentConfig(
name=self.name,
team=self._team.dump_component(),
model_client=self._model_client.dump_component(),
description=self.description,
instruction=self._instruction,
response_prompt=self._response_prompt,
)

@classmethod
def _from_config(cls, config: SocietyOfMindAgentConfig) -> Self:
model_client = ChatCompletionClient.load_component(config.model_client)
team = Team.load_component(config.team)
return cls(
name=config.name,
team=team,
model_client=model_client,
description=config.description,
instruction=config.instruction,
response_prompt=config.response_prompt,
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Mapping, Sequence

from autogen_core import CancellationToken
from autogen_core import CancellationToken, ComponentBase
from pydantic import BaseModel

from ..messages import AgentEvent, ChatMessage
from ._task import TaskRunner
Expand All @@ -20,9 +21,11 @@ class Response:
or :class:`ChatMessage`."""


class ChatAgent(ABC, TaskRunner):
class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]):
"""Protocol for a chat agent."""

component_type = "agent"

@property
@abstractmethod
def name(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Any, Mapping
from abc import ABC, abstractmethod
from typing import Any, Mapping

from autogen_core import ComponentBase
from pydantic import BaseModel

from ._task import TaskRunner


class Team(ABC, TaskRunner):
class Team(ABC, TaskRunner, ComponentBase[BaseModel]):
component_type = "team"

@abstractmethod
async def reset(self) -> None:
"""Reset the team and all its participants to its initial state."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
AgentType,
CancellationToken,
ClosureAgent,
ComponentBase,
MessageContext,
SingleThreadedAgentRuntime,
TypeSubscription,
)
from autogen_core._closure_agent import ClosureContext
from pydantic import BaseModel

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
Expand All @@ -28,13 +30,15 @@
event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class BaseGroupChat(Team, ABC):
class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
"""The base class for group chat teams.
To implement a group chat team, first create a subclass of :class:`BaseGroupChatManager` and then
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
"""

component_type = "team"

def __init__(
self,
participants: List[ChatAgent],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging
from typing import Callable, List

from autogen_core import Component, ComponentModel
from autogen_core.models import ChatCompletionClient
from pydantic import BaseModel
from typing_extensions import Self

from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ....base import ChatAgent, TerminationCondition
Expand All @@ -13,7 +16,18 @@
event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class MagenticOneGroupChat(BaseGroupChat):
class MagenticOneGroupChatConfig(BaseModel):
"""The declarative configuration for a MagenticOneGroupChat."""

participants: List[ComponentModel]
model_client: ComponentModel
termination_condition: ComponentModel | None = None
max_turns: int | None = None
max_stalls: int
final_answer_prompt: str


class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]):
"""A team that runs a group chat with participants managed by the MagenticOneOrchestrator.
The orchestrator handles the conversation flow, ensuring that the task is completed
Expand Down Expand Up @@ -73,6 +87,9 @@ async def main() -> None:
}
"""

component_config_schema = MagenticOneGroupChatConfig
component_provider_override = "autogen_agentchat.teams.MagenticOneGroupChat"

def __init__(
self,
participants: List[ChatAgent],
Expand Down Expand Up @@ -117,3 +134,31 @@ def _create_group_chat_manager_factory(
self._final_answer_prompt,
termination_condition,
)

def _to_config(self) -> MagenticOneGroupChatConfig:
participants = [participant.dump_component() for participant in self._participants]
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
return MagenticOneGroupChatConfig(
participants=participants,
model_client=self._model_client.dump_component(),
termination_condition=termination_condition,
max_turns=self._max_turns,
max_stalls=self._max_stalls,
final_answer_prompt=self._final_answer_prompt,
)

@classmethod
def _from_config(cls, config: MagenticOneGroupChatConfig) -> Self:
participants = [ChatAgent.load_component(participant) for participant in config.participants]
model_client = ChatCompletionClient.load_component(config.model_client)
termination_condition = (
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
)
return cls(
participants,
model_client,
termination_condition=termination_condition,
max_turns=config.max_turns,
max_stalls=config.max_stalls,
final_answer_prompt=config.final_answer_prompt,
)
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Any, Callable, List, Mapping

from autogen_core import Component, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self

from ...base import ChatAgent, TerminationCondition
from ...messages import AgentEvent, ChatMessage
from ...state import RoundRobinManagerState
Expand Down Expand Up @@ -61,7 +65,15 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
return current_speaker


class RoundRobinGroupChat(BaseGroupChat):
class RoundRobinGroupChatConfig(BaseModel):
"""The declarative configuration RoundRobinGroupChat."""

participants: List[ComponentModel]
termination_condition: ComponentModel | None = None
max_turns: int | None = None


class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
"""A team that runs a group chat with participants taking turns in a round-robin fashion
to publish a message to all.
Expand Down Expand Up @@ -133,6 +145,9 @@ async def main() -> None:
asyncio.run(main())
"""

component_config_schema = RoundRobinGroupChatConfig
component_provider_override = "autogen_agentchat.teams.RoundRobinGroupChat"

def __init__(
self,
participants: List[ChatAgent],
Expand Down Expand Up @@ -166,3 +181,20 @@ def _factory() -> RoundRobinGroupChatManager:
)

return _factory

def _to_config(self) -> RoundRobinGroupChatConfig:
participants = [participant.dump_component() for participant in self._participants]
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
return RoundRobinGroupChatConfig(
participants=participants,
termination_condition=termination_condition,
max_turns=self._max_turns,
)

@classmethod
def _from_config(cls, config: RoundRobinGroupChatConfig) -> Self:
participants = [ChatAgent.load_component(participant) for participant in config.participants]
termination_condition = (
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
)
return cls(participants, termination_condition=termination_condition, max_turns=config.max_turns)
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import re
from typing import Any, Callable, Dict, List, Mapping, Sequence

from autogen_core import Component, ComponentModel
from autogen_core.models import ChatCompletionClient, SystemMessage
from pydantic import BaseModel
from typing_extensions import Self

from ... import TRACE_LOGGER_NAME
from ...agents import BaseChatAgent
from ...base import ChatAgent, TerminationCondition
from ...messages import (
AgentEvent,
Expand Down Expand Up @@ -184,7 +188,19 @@ def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dic
return mentions


class SelectorGroupChat(BaseGroupChat):
class SelectorGroupChatConfig(BaseModel):
"""The declarative configuration for SelectorGroupChat."""

participants: List[ComponentModel]
model_client: ComponentModel
termination_condition: ComponentModel | None = None
max_turns: int | None = None
selector_prompt: str
allow_repeated_speaker: bool
# selector_func: ComponentModel | None


class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
"""A group chat team that have participants takes turn to publish a message
to all, using a ChatCompletion model to select the next speaker after each message.
Expand Down Expand Up @@ -321,6 +337,9 @@ def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
asyncio.run(main())
"""

component_config_schema = SelectorGroupChatConfig
component_provider_override = "autogen_agentchat.teams.SelectorGroupChat"

def __init__(
self,
participants: List[ChatAgent],
Expand Down Expand Up @@ -381,3 +400,30 @@ def _create_group_chat_manager_factory(
self._allow_repeated_speaker,
self._selector_func,
)

def _to_config(self) -> SelectorGroupChatConfig:
return SelectorGroupChatConfig(
participants=[participant.dump_component() for participant in self._participants],
model_client=self._model_client.dump_component(),
termination_condition=self._termination_condition.dump_component() if self._termination_condition else None,
max_turns=self._max_turns,
selector_prompt=self._selector_prompt,
allow_repeated_speaker=self._allow_repeated_speaker,
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
)

@classmethod
def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
return cls(
participants=[BaseChatAgent.load_component(participant) for participant in config.participants],
model_client=ChatCompletionClient.load_component(config.model_client),
termination_condition=TerminationCondition.load_component(config.termination_condition)
if config.termination_condition
else None,
max_turns=config.max_turns,
selector_prompt=config.selector_prompt,
allow_repeated_speaker=config.allow_repeated_speaker,
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[AgentEvent | ChatMessage]], str | None])
# if config.selector_func
# else None,
)
Loading

0 comments on commit 2814243

Please sign in to comment.