From 1da58450568ca2b902935f0b2bac80dd4dede49b Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sat, 28 Feb 2026 13:40:28 +0100 Subject: [PATCH 1/3] feat: implement message and communication domain models (#58) Add pure domain models for inter-agent communication covering message format, channels, and loop prevention config per DESIGN_SPEC 5.3-5.5. Source models: enums, Attachment, MessageMetadata, Message (with from/sender alias), Channel, and 8 config models including LoopPreventionConfig with ancestry_tracking enforcement. 140 unit tests at 100% module coverage. --- CLAUDE.md | 4 + src/ai_company/communication/__init__.py | 43 ++ src/ai_company/communication/channel.py | 51 +++ src/ai_company/communication/config.py | 290 +++++++++++++ src/ai_company/communication/enums.py | 72 ++++ src/ai_company/communication/message.py | 161 +++++++ tests/unit/communication/__init__.py | 0 tests/unit/communication/conftest.py | 147 +++++++ tests/unit/communication/test_channel.py | 120 ++++++ tests/unit/communication/test_config.py | 511 +++++++++++++++++++++++ tests/unit/communication/test_enums.py | 97 +++++ tests/unit/communication/test_message.py | 366 ++++++++++++++++ 12 files changed, 1862 insertions(+) create mode 100644 src/ai_company/communication/channel.py create mode 100644 src/ai_company/communication/config.py create mode 100644 src/ai_company/communication/enums.py create mode 100644 src/ai_company/communication/message.py create mode 100644 tests/unit/communication/__init__.py create mode 100644 tests/unit/communication/conftest.py create mode 100644 tests/unit/communication/test_channel.py create mode 100644 tests/unit/communication/test_config.py create mode 100644 tests/unit/communication/test_enums.py create mode 100644 tests/unit/communication/test_message.py diff --git a/CLAUDE.md b/CLAUDE.md index ef556e979c..badc62c0ce 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -41,6 +41,10 @@ src/ai_company/ tools/ # Tool registry, MCP integration, role-based access ``` +## Shell Usage + +- **NEVER use `cd` in Bash commands** — the working directory is already set to the project root. Use absolute paths or run commands directly. Do NOT prefix commands with `cd C:/Users/Aurelio/ai-company &&`. + ## Code Conventions - **No `from __future__ import annotations`** — Python 3.14 has PEP 649 diff --git a/src/ai_company/communication/__init__.py b/src/ai_company/communication/__init__.py index e69de29bb2..58312f5d0a 100644 --- a/src/ai_company/communication/__init__.py +++ b/src/ai_company/communication/__init__.py @@ -0,0 +1,43 @@ +"""Communication domain models for the AI company framework.""" + +from ai_company.communication.channel import Channel +from ai_company.communication.config import ( + CircuitBreakerConfig, + CommunicationConfig, + HierarchyConfig, + LoopPreventionConfig, + MeetingsConfig, + MeetingTypeConfig, + MessageBusConfig, + RateLimitConfig, +) +from ai_company.communication.enums import ( + AttachmentType, + ChannelType, + CommunicationPattern, + MessageBusBackend, + MessagePriority, + MessageType, +) +from ai_company.communication.message import Attachment, Message, MessageMetadata + +__all__ = [ + "Attachment", + "AttachmentType", + "Channel", + "ChannelType", + "CircuitBreakerConfig", + "CommunicationConfig", + "CommunicationPattern", + "HierarchyConfig", + "LoopPreventionConfig", + "MeetingTypeConfig", + "MeetingsConfig", + "Message", + "MessageBusBackend", + "MessageBusConfig", + "MessageMetadata", + "MessagePriority", + "MessageType", + "RateLimitConfig", +] diff --git a/src/ai_company/communication/channel.py b/src/ai_company/communication/channel.py new file mode 100644 index 0000000000..24e8846619 --- /dev/null +++ b/src/ai_company/communication/channel.py @@ -0,0 +1,51 @@ +"""Channel domain model.""" + +from collections import Counter +from typing import Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.communication.enums import ChannelType + + +class Channel(BaseModel): + """A named communication channel that agents can subscribe to. + + Attributes: + name: Channel name (e.g. ``"#engineering"``). + type: Channel delivery semantics. + subscribers: Agent IDs subscribed to this channel. + """ + + model_config = ConfigDict(frozen=True) + + name: str = Field(min_length=1, description="Channel name") + type: ChannelType = Field( + default=ChannelType.TOPIC, + description="Channel delivery semantics", + ) + subscribers: tuple[str, ...] = Field( + default=(), + description="Agent IDs subscribed to this channel", + ) + + @model_validator(mode="after") + def _validate_name_not_blank(self) -> Self: + """Ensure name is not whitespace-only.""" + if not self.name.strip(): + msg = "name must not be whitespace-only" + raise ValueError(msg) + return self + + @model_validator(mode="after") + def _validate_subscribers(self) -> Self: + """Ensure subscriber entries are non-blank and unique.""" + for sub in self.subscribers: + if not sub.strip(): + msg = "Empty or whitespace-only entry in subscribers" + raise ValueError(msg) + if len(self.subscribers) != len(set(self.subscribers)): + dupes = sorted(s for s, c in Counter(self.subscribers).items() if c > 1) + msg = f"Duplicate entries in subscribers: {dupes}" + raise ValueError(msg) + return self diff --git a/src/ai_company/communication/config.py b/src/ai_company/communication/config.py new file mode 100644 index 0000000000..296a45f962 --- /dev/null +++ b/src/ai_company/communication/config.py @@ -0,0 +1,290 @@ +"""Communication configuration models (DESIGN_SPEC Sections 5.4, 5.5).""" + +from collections import Counter +from typing import Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.communication.enums import ( + CommunicationPattern, + MessageBusBackend, +) + +# Default channels from DESIGN_SPEC Section 5.4. +_DEFAULT_CHANNELS: tuple[str, ...] = ( + "#all-hands", + "#engineering", + "#product", + "#design", + "#incidents", + "#code-review", + "#watercooler", +) + + +class MessageBusConfig(BaseModel): + """Message bus backend configuration. + + Maps to DESIGN_SPEC Section 5.4 ``message_bus``. + + Attributes: + backend: Transport backend to use. + channels: Pre-defined channel names. + """ + + model_config = ConfigDict(frozen=True) + + backend: MessageBusBackend = Field( + default=MessageBusBackend.INTERNAL, + description="Transport backend", + ) + channels: tuple[str, ...] = Field( + default=_DEFAULT_CHANNELS, + description="Pre-defined channel names", + ) + + @model_validator(mode="after") + def _validate_channels(self) -> Self: + """Ensure channel names are non-blank and unique.""" + for ch in self.channels: + if not ch.strip(): + msg = "Empty or whitespace-only entry in channels" + raise ValueError(msg) + if len(self.channels) != len(set(self.channels)): + dupes = sorted(c for c, n in Counter(self.channels).items() if n > 1) + msg = f"Duplicate entries in channels: {dupes}" + raise ValueError(msg) + return self + + +class MeetingTypeConfig(BaseModel): + """Configuration for a single meeting type. + + Maps to DESIGN_SPEC Section 5.4 ``meetings.types[]``. Exactly one of + ``frequency`` or ``trigger`` must be set. + + Attributes: + name: Meeting type name (e.g. ``"daily_standup"``). + frequency: Recurrence schedule (mutually exclusive with trigger). + trigger: Event trigger (mutually exclusive with frequency). + participants: Participant role or agent identifiers. + duration_tokens: Token budget for the meeting. + """ + + model_config = ConfigDict(frozen=True) + + name: str = Field(min_length=1, description="Meeting type name") + frequency: str | None = Field( + default=None, + min_length=1, + description="Recurrence schedule", + ) + trigger: str | None = Field( + default=None, + min_length=1, + description="Event trigger", + ) + participants: tuple[str, ...] = Field( + default=(), + description="Participant role or agent identifiers", + ) + duration_tokens: int = Field( + default=2000, + gt=0, + description="Token budget for the meeting", + ) + + @model_validator(mode="after") + def _validate_frequency_or_trigger(self) -> Self: + """Exactly one of frequency or trigger must be set.""" + if self.frequency is not None and self.trigger is not None: + msg = "Only one of frequency or trigger may be set, not both" + raise ValueError(msg) + if self.frequency is None and self.trigger is None: + msg = "Exactly one of frequency or trigger must be set" + raise ValueError(msg) + return self + + +class MeetingsConfig(BaseModel): + """Meetings subsystem configuration. + + Maps to DESIGN_SPEC Section 5.4 ``meetings``. + + Attributes: + enabled: Whether the meetings subsystem is active. + types: Configured meeting types (unique by name). + """ + + model_config = ConfigDict(frozen=True) + + enabled: bool = Field(default=True, description="Meetings subsystem active") + types: tuple[MeetingTypeConfig, ...] = Field( + default=(), + description="Configured meeting types", + ) + + @model_validator(mode="after") + def _validate_unique_meeting_names(self) -> Self: + """Ensure meeting type names are unique.""" + names = [mt.name for mt in self.types] + if len(names) != len(set(names)): + dupes = sorted(n for n, c in Counter(names).items() if c > 1) + msg = f"Duplicate meeting type names: {dupes}" + raise ValueError(msg) + return self + + +class HierarchyConfig(BaseModel): + """Hierarchy enforcement configuration. + + Maps to DESIGN_SPEC Section 5.4 ``hierarchy``. + + Attributes: + enforce_chain_of_command: Whether chain-of-command is enforced. + allow_skip_level: Whether skip-level messaging is allowed. + """ + + model_config = ConfigDict(frozen=True) + + enforce_chain_of_command: bool = Field( + default=True, + description="Enforce chain-of-command", + ) + allow_skip_level: bool = Field( + default=False, + description="Allow skip-level messaging", + ) + + +class RateLimitConfig(BaseModel): + """Per-pair message rate limit configuration. + + Maps to DESIGN_SPEC Section 5.5 ``rate_limit``. + + Attributes: + max_per_pair_per_minute: Maximum messages per agent pair per minute. + burst_allowance: Extra burst capacity above the rate limit. + """ + + model_config = ConfigDict(frozen=True) + + max_per_pair_per_minute: int = Field( + default=10, + gt=0, + description="Max messages per agent pair per minute", + ) + burst_allowance: int = Field( + default=3, + ge=0, + description="Extra burst capacity", + ) + + +class CircuitBreakerConfig(BaseModel): + """Circuit breaker configuration for agent-pair communication. + + Maps to DESIGN_SPEC Section 5.5 ``circuit_breaker``. + + Attributes: + bounce_threshold: Bounce count before the circuit opens. + cooldown_seconds: Seconds to wait before retrying after trip. + """ + + model_config = ConfigDict(frozen=True) + + bounce_threshold: int = Field( + default=3, + gt=0, + description="Bounce count before circuit opens", + ) + cooldown_seconds: int = Field( + default=300, + gt=0, + description="Cooldown period in seconds", + ) + + +class LoopPreventionConfig(BaseModel): + """Loop prevention safeguards. + + Maps to DESIGN_SPEC Section 5.5. ``ancestry_tracking`` is always on + and cannot be disabled. + + Attributes: + max_delegation_depth: Hard limit on delegation chain length. + rate_limit: Per-pair rate limit settings. + dedup_window_seconds: Deduplication window in seconds. + circuit_breaker: Circuit breaker settings. + ancestry_tracking: Must always be ``True``. + """ + + model_config = ConfigDict(frozen=True) + + max_delegation_depth: int = Field( + default=5, + gt=0, + description="Hard limit on delegation chain length", + ) + rate_limit: RateLimitConfig = Field( + default_factory=RateLimitConfig, + description="Per-pair rate limit settings", + ) + dedup_window_seconds: int = Field( + default=60, + gt=0, + description="Deduplication window in seconds", + ) + circuit_breaker: CircuitBreakerConfig = Field( + default_factory=CircuitBreakerConfig, + description="Circuit breaker settings", + ) + ancestry_tracking: bool = Field( + default=True, + description="Task ancestry tracking (always on)", + ) + + @model_validator(mode="after") + def _validate_ancestry_tracking(self) -> Self: + """Ancestry tracking must always be enabled (spec: not configurable).""" + if not self.ancestry_tracking: + msg = "ancestry_tracking must be True (always on, not configurable)" + raise ValueError(msg) + return self + + +class CommunicationConfig(BaseModel): + """Top-level communication configuration. + + Aggregates DESIGN_SPEC Sections 5.4 and 5.5 under a single model. + + Attributes: + default_pattern: High-level communication pattern. + message_bus: Message bus configuration. + meetings: Meetings subsystem configuration. + hierarchy: Hierarchy enforcement settings. + loop_prevention: Loop prevention safeguards. + """ + + model_config = ConfigDict(frozen=True) + + default_pattern: CommunicationPattern = Field( + default=CommunicationPattern.HYBRID, + description="High-level communication pattern", + ) + message_bus: MessageBusConfig = Field( + default_factory=MessageBusConfig, + description="Message bus configuration", + ) + meetings: MeetingsConfig = Field( + default_factory=MeetingsConfig, + description="Meetings subsystem configuration", + ) + hierarchy: HierarchyConfig = Field( + default_factory=HierarchyConfig, + description="Hierarchy enforcement settings", + ) + loop_prevention: LoopPreventionConfig = Field( + default_factory=LoopPreventionConfig, + description="Loop prevention safeguards", + ) diff --git a/src/ai_company/communication/enums.py b/src/ai_company/communication/enums.py new file mode 100644 index 0000000000..f6f3fee6b3 --- /dev/null +++ b/src/ai_company/communication/enums.py @@ -0,0 +1,72 @@ +"""Communication domain enumerations.""" + +from enum import StrEnum + + +class MessageType(StrEnum): + """Type of inter-agent message. + + Maps to the ``type`` field in DESIGN_SPEC Section 5.3. + """ + + TASK_UPDATE = "task_update" + QUESTION = "question" + ANNOUNCEMENT = "announcement" + REVIEW_REQUEST = "review_request" + APPROVAL = "approval" + DELEGATION = "delegation" + STATUS_REPORT = "status_report" + ESCALATION = "escalation" + + +class MessagePriority(StrEnum): + """Priority level for messages. + + Separate from :class:`ai_company.core.enums.Priority` which uses + ``"medium"``; message priority uses ``"normal"`` per DESIGN_SPEC 5.3. + """ + + LOW = "low" + NORMAL = "normal" + HIGH = "high" + URGENT = "urgent" + + +class ChannelType(StrEnum): + """Channel delivery semantics.""" + + TOPIC = "topic" + DIRECT = "direct" + BROADCAST = "broadcast" + + +class AttachmentType(StrEnum): + """Type of message attachment.""" + + ARTIFACT = "artifact" + FILE = "file" + LINK = "link" + + +class CommunicationPattern(StrEnum): + """High-level communication pattern for the company. + + Maps to DESIGN_SPEC Section 5.1. + """ + + EVENT_DRIVEN = "event_driven" + HIERARCHICAL = "hierarchical" + MEETING_BASED = "meeting_based" + HYBRID = "hybrid" + + +class MessageBusBackend(StrEnum): + """Message bus backend implementation. + + Maps to DESIGN_SPEC Section 5.4 ``message_bus.backend``. + """ + + INTERNAL = "internal" + REDIS = "redis" + RABBITMQ = "rabbitmq" + KAFKA = "kafka" diff --git a/src/ai_company/communication/message.py b/src/ai_company/communication/message.py new file mode 100644 index 0000000000..2b67b514df --- /dev/null +++ b/src/ai_company/communication/message.py @@ -0,0 +1,161 @@ +"""Message domain models (DESIGN_SPEC Section 5.3).""" + +from collections import Counter +from datetime import datetime # noqa: TC003 +from typing import Self +from uuid import UUID, uuid4 + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.communication.enums import ( + AttachmentType, + MessagePriority, + MessageType, +) + + +class Attachment(BaseModel): + """A reference attached to a message. + + Attributes: + type: The kind of attachment. + ref: Reference identifier (e.g. artifact ID, URL, file path). + """ + + model_config = ConfigDict(frozen=True) + + type: AttachmentType = Field(description="Kind of attachment") + ref: str = Field(min_length=1, description="Reference identifier") + + @model_validator(mode="after") + def _validate_ref_not_blank(self) -> Self: + """Ensure ref is not whitespace-only.""" + if not self.ref.strip(): + msg = "ref must not be whitespace-only" + raise ValueError(msg) + return self + + +class MessageMetadata(BaseModel): + """Optional metadata carried with a message. + + Attributes: + task_id: Related task identifier. + project_id: Related project identifier. + tokens_used: LLM tokens consumed producing the message. + cost_usd: Estimated USD cost of the message. + extra: Immutable key-value pairs for arbitrary metadata. + """ + + model_config = ConfigDict(frozen=True) + + task_id: str | None = Field( + default=None, + min_length=1, + description="Related task identifier", + ) + project_id: str | None = Field( + default=None, + min_length=1, + description="Related project identifier", + ) + tokens_used: int | None = Field( + default=None, + ge=0, + description="LLM tokens consumed", + ) + cost_usd: float | None = Field( + default=None, + ge=0.0, + description="Estimated USD cost", + ) + extra: tuple[tuple[str, str], ...] = Field( + default=(), + description="Immutable key-value pairs for arbitrary metadata", + ) + + @model_validator(mode="after") + def _validate_optional_strings(self) -> Self: + """Ensure optional string fields are not whitespace-only.""" + for field_name in ("task_id", "project_id"): + value = getattr(self, field_name) + if value is not None and not value.strip(): + msg = f"{field_name} must not be whitespace-only" + raise ValueError(msg) + return self + + @model_validator(mode="after") + def _validate_extra(self) -> Self: + """Ensure extra keys are non-blank and unique.""" + keys: list[str] = [] + for key, _value in self.extra: + if not key.strip(): + msg = "extra keys must not be blank" + raise ValueError(msg) + keys.append(key) + if len(keys) != len(set(keys)): + dupes = sorted(k for k, c in Counter(keys).items() if c > 1) + msg = f"Duplicate keys in extra: {dupes}" + raise ValueError(msg) + return self + + +class Message(BaseModel): + """An inter-agent message. + + Field schema matches DESIGN_SPEC Section 5.3. The ``sender`` field + is aliased to ``"from"`` for JSON compatibility with the spec format. + + Attributes: + id: Unique message identifier. + timestamp: When the message was created. + sender: Agent ID of the sender (aliased to ``"from"`` in JSON). + to: Recipient agent or channel identifier. + type: Message type classification. + priority: Message priority level. + channel: Channel the message is sent through. + content: Message body text. + attachments: Attached references. + metadata: Optional message metadata. + """ + + model_config = ConfigDict(frozen=True, populate_by_name=True) + + id: UUID = Field( + default_factory=uuid4, + description="Unique message identifier", + ) + timestamp: datetime = Field(description="When the message was created") + sender: str = Field( + min_length=1, + alias="from", + description="Sender agent ID", + ) + to: str = Field(min_length=1, description="Recipient agent or channel") + type: MessageType = Field(description="Message type classification") + priority: MessagePriority = Field( + default=MessagePriority.NORMAL, + description="Message priority level", + ) + channel: str = Field( + min_length=1, + description="Channel the message is sent through", + ) + content: str = Field(min_length=1, description="Message body text") + attachments: tuple[Attachment, ...] = Field( + default=(), + description="Attached references", + ) + metadata: MessageMetadata = Field( + default_factory=MessageMetadata, + description="Optional message metadata", + ) + + @model_validator(mode="after") + def _validate_strings_not_blank(self) -> Self: + """Ensure required string fields are not whitespace-only.""" + for field_name in ("sender", "to", "channel", "content"): + if not getattr(self, field_name).strip(): + msg = f"{field_name} must not be whitespace-only" + raise ValueError(msg) + return self diff --git a/tests/unit/communication/__init__.py b/tests/unit/communication/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/communication/conftest.py b/tests/unit/communication/conftest.py new file mode 100644 index 0000000000..2c2734d62e --- /dev/null +++ b/tests/unit/communication/conftest.py @@ -0,0 +1,147 @@ +"""Unit test configuration and fixtures for communication models.""" + +from datetime import UTC, datetime + +import pytest +from polyfactory.factories.pydantic_factory import ModelFactory + +from ai_company.communication.channel import Channel +from ai_company.communication.config import ( + CircuitBreakerConfig, + CommunicationConfig, + HierarchyConfig, + LoopPreventionConfig, + MeetingsConfig, + MeetingTypeConfig, + MessageBusConfig, + RateLimitConfig, +) +from ai_company.communication.enums import ( + ChannelType, + MessagePriority, + MessageType, +) +from ai_company.communication.message import Attachment, Message, MessageMetadata + +# ── Factories ────────────────────────────────────────────────────── + + +class AttachmentFactory(ModelFactory): + __model__ = Attachment + + +class MessageMetadataFactory(ModelFactory): + __model__ = MessageMetadata + task_id = None + project_id = None + tokens_used = None + cost_usd = None + extra = () + + +class MessageFactory(ModelFactory): + __model__ = Message + priority = MessagePriority.NORMAL + attachments = () + metadata = MessageMetadataFactory + + +class ChannelFactory(ModelFactory): + __model__ = Channel + type = ChannelType.TOPIC + subscribers = () + + +class MessageBusConfigFactory(ModelFactory): + __model__ = MessageBusConfig + + +class MeetingTypeConfigFactory(ModelFactory): + __model__ = MeetingTypeConfig + frequency = "daily" + trigger = None + + +class MeetingsConfigFactory(ModelFactory): + __model__ = MeetingsConfig + types = () + + +class HierarchyConfigFactory(ModelFactory): + __model__ = HierarchyConfig + + +class RateLimitConfigFactory(ModelFactory): + __model__ = RateLimitConfig + + +class CircuitBreakerConfigFactory(ModelFactory): + __model__ = CircuitBreakerConfig + + +class LoopPreventionConfigFactory(ModelFactory): + __model__ = LoopPreventionConfig + ancestry_tracking = True + + +class CommunicationConfigFactory(ModelFactory): + __model__ = CommunicationConfig + meetings = MeetingsConfigFactory + loop_prevention = LoopPreventionConfigFactory + + +# ── Sample Fixtures ──────────────────────────────────────────────── + + +@pytest.fixture +def sample_attachment() -> Attachment: + return Attachment(type="artifact", ref="pr-42") + + +@pytest.fixture +def sample_metadata() -> MessageMetadata: + return MessageMetadata( + task_id="task-123", + project_id="proj-456", + tokens_used=1200, + cost_usd=0.018, + ) + + +@pytest.fixture +def sample_message(sample_metadata: MessageMetadata) -> Message: + return Message( + timestamp=datetime(2026, 2, 27, 10, 30, tzinfo=UTC), + sender="sarah_chen", + to="engineering", + type=MessageType.TASK_UPDATE, + priority=MessagePriority.NORMAL, + channel="#backend", + content="Completed API endpoint for user authentication.", + attachments=(Attachment(type="artifact", ref="pr-42"),), + metadata=sample_metadata, + ) + + +@pytest.fixture +def sample_channel() -> Channel: + return Channel( + name="#engineering", + type=ChannelType.TOPIC, + subscribers=("sarah_chen", "backend_lead"), + ) + + +@pytest.fixture +def sample_meeting_type() -> MeetingTypeConfig: + return MeetingTypeConfig( + name="daily_standup", + frequency="per_sprint_day", + participants=("engineering", "qa"), + duration_tokens=2000, + ) + + +@pytest.fixture +def sample_communication_config() -> CommunicationConfig: + return CommunicationConfig() diff --git a/tests/unit/communication/test_channel.py b/tests/unit/communication/test_channel.py new file mode 100644 index 0000000000..df17111652 --- /dev/null +++ b/tests/unit/communication/test_channel.py @@ -0,0 +1,120 @@ +"""Tests for the Channel domain model.""" + +import pytest +from pydantic import ValidationError + +from ai_company.communication.channel import Channel +from ai_company.communication.enums import ChannelType + +pytestmark = pytest.mark.timeout(30) + + +# ── Channel: Construction & Defaults ──────────────────────────── + + +@pytest.mark.unit +class TestChannelConstruction: + def test_minimal_valid(self) -> None: + ch = Channel(name="#engineering") + assert ch.name == "#engineering" + assert ch.type is ChannelType.TOPIC + assert ch.subscribers == () + + def test_all_fields_set(self) -> None: + ch = Channel( + name="#backend", + type=ChannelType.DIRECT, + subscribers=("agent-a", "agent-b"), + ) + assert ch.name == "#backend" + assert ch.type is ChannelType.DIRECT + assert ch.subscribers == ("agent-a", "agent-b") + + +# ── Channel: Validation ───────────────────────────────────────── + + +@pytest.mark.unit +class TestChannelValidation: + def test_empty_name_rejected(self) -> None: + with pytest.raises(ValidationError): + Channel(name="") + + def test_whitespace_name_rejected(self) -> None: + with pytest.raises(ValidationError, match="name must not be whitespace-only"): + Channel(name=" ") + + def test_whitespace_subscriber_rejected(self) -> None: + with pytest.raises( + ValidationError, match="Empty or whitespace-only entry in subscribers" + ): + Channel(name="#test", subscribers=("agent-a", " ")) + + def test_duplicate_subscribers_rejected(self) -> None: + with pytest.raises(ValidationError, match="Duplicate entries in subscribers"): + Channel(name="#test", subscribers=("agent-a", "agent-b", "agent-a")) + + def test_valid_subscribers(self) -> None: + ch = Channel(name="#test", subscribers=("a", "b", "c")) + assert ch.subscribers == ("a", "b", "c") + + +# ── Channel: Immutability ─────────────────────────────────────── + + +@pytest.mark.unit +class TestChannelImmutability: + def test_frozen(self) -> None: + ch = Channel(name="#test") + with pytest.raises(ValidationError): + ch.name = "#other" # type: ignore[misc] + + def test_model_copy(self) -> None: + original = Channel(name="#test", subscribers=("a",)) + updated = original.model_copy(update={"name": "#updated"}) + assert updated.name == "#updated" + assert original.name == "#test" + + +# ── Channel: Serialization ────────────────────────────────────── + + +@pytest.mark.unit +class TestChannelSerialization: + def test_json_roundtrip(self) -> None: + ch = Channel( + name="#engineering", + type=ChannelType.BROADCAST, + subscribers=("a", "b"), + ) + restored = Channel.model_validate_json(ch.model_dump_json()) + assert restored == ch + + def test_model_dump(self) -> None: + ch = Channel(name="#test", type=ChannelType.DIRECT) + dumped = ch.model_dump() + assert dumped["type"] == "direct" + + +# ── Channel: Factory ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestChannelFactory: + def test_factory(self) -> None: + from tests.unit.communication.conftest import ChannelFactory + + ch = ChannelFactory.build() + assert isinstance(ch, Channel) + assert len(ch.name) >= 1 + + +# ── Channel: Fixtures ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestChannelFixtures: + def test_sample_channel(self, sample_channel: Channel) -> None: + assert sample_channel.name == "#engineering" + assert sample_channel.type is ChannelType.TOPIC + assert len(sample_channel.subscribers) == 2 diff --git a/tests/unit/communication/test_config.py b/tests/unit/communication/test_config.py new file mode 100644 index 0000000000..e7f82acabd --- /dev/null +++ b/tests/unit/communication/test_config.py @@ -0,0 +1,511 @@ +"""Tests for the communication configuration models.""" + +import pytest +from pydantic import ValidationError + +from ai_company.communication.config import ( + CircuitBreakerConfig, + CommunicationConfig, + HierarchyConfig, + LoopPreventionConfig, + MeetingsConfig, + MeetingTypeConfig, + MessageBusConfig, + RateLimitConfig, +) +from ai_company.communication.enums import ( + CommunicationPattern, + MessageBusBackend, +) + +pytestmark = pytest.mark.timeout(30) + + +# ── MessageBusConfig ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestMessageBusConfigDefaults: + def test_defaults(self) -> None: + cfg = MessageBusConfig() + assert cfg.backend is MessageBusBackend.INTERNAL + assert len(cfg.channels) == 7 + assert "#all-hands" in cfg.channels + assert "#engineering" in cfg.channels + assert "#watercooler" in cfg.channels + + def test_custom_values(self) -> None: + cfg = MessageBusConfig( + backend=MessageBusBackend.REDIS, + channels=("#ops", "#alerts"), + ) + assert cfg.backend is MessageBusBackend.REDIS + assert cfg.channels == ("#ops", "#alerts") + + +@pytest.mark.unit +class TestMessageBusConfigValidation: + def test_whitespace_channel_rejected(self) -> None: + with pytest.raises( + ValidationError, match="Empty or whitespace-only entry in channels" + ): + MessageBusConfig(channels=("#valid", " ")) + + def test_duplicate_channels_rejected(self) -> None: + with pytest.raises(ValidationError, match="Duplicate entries in channels"): + MessageBusConfig(channels=("#a", "#b", "#a")) + + def test_empty_channels_allowed(self) -> None: + cfg = MessageBusConfig(channels=()) + assert cfg.channels == () + + +@pytest.mark.unit +class TestMessageBusConfigImmutability: + def test_frozen(self) -> None: + cfg = MessageBusConfig() + with pytest.raises(ValidationError): + cfg.backend = MessageBusBackend.KAFKA # type: ignore[misc] + + def test_model_copy(self) -> None: + original = MessageBusConfig() + updated = original.model_copy(update={"backend": MessageBusBackend.RABBITMQ}) + assert updated.backend is MessageBusBackend.RABBITMQ + assert original.backend is MessageBusBackend.INTERNAL + + +@pytest.mark.unit +class TestMessageBusConfigSerialization: + def test_json_roundtrip(self) -> None: + cfg = MessageBusConfig( + backend=MessageBusBackend.KAFKA, + channels=("#a", "#b"), + ) + restored = MessageBusConfig.model_validate_json(cfg.model_dump_json()) + assert restored == cfg + + def test_factory(self) -> None: + from tests.unit.communication.conftest import MessageBusConfigFactory + + cfg = MessageBusConfigFactory.build() + assert isinstance(cfg, MessageBusConfig) + + +# ── MeetingTypeConfig ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestMeetingTypeConfigConstruction: + def test_with_frequency(self) -> None: + mt = MeetingTypeConfig(name="standup", frequency="daily") + assert mt.name == "standup" + assert mt.frequency == "daily" + assert mt.trigger is None + assert mt.participants == () + assert mt.duration_tokens == 2000 + + def test_with_trigger(self) -> None: + mt = MeetingTypeConfig(name="code_review", trigger="on_pr") + assert mt.trigger == "on_pr" + assert mt.frequency is None + + def test_custom_values(self) -> None: + mt = MeetingTypeConfig( + name="planning", + frequency="bi_weekly", + participants=("all",), + duration_tokens=5000, + ) + assert mt.participants == ("all",) + assert mt.duration_tokens == 5000 + + +@pytest.mark.unit +class TestMeetingTypeConfigValidation: + def test_both_frequency_and_trigger_rejected(self) -> None: + with pytest.raises( + ValidationError, + match="Only one of frequency or trigger may be set", + ): + MeetingTypeConfig(name="bad", frequency="daily", trigger="on_pr") + + def test_neither_frequency_nor_trigger_rejected(self) -> None: + with pytest.raises( + ValidationError, + match="Exactly one of frequency or trigger must be set", + ): + MeetingTypeConfig(name="bad") + + def test_empty_name_rejected(self) -> None: + with pytest.raises(ValidationError): + MeetingTypeConfig(name="", frequency="daily") + + def test_zero_duration_rejected(self) -> None: + with pytest.raises(ValidationError): + MeetingTypeConfig(name="bad", frequency="daily", duration_tokens=0) + + def test_negative_duration_rejected(self) -> None: + with pytest.raises(ValidationError): + MeetingTypeConfig(name="bad", frequency="daily", duration_tokens=-1) + + +@pytest.mark.unit +class TestMeetingTypeConfigImmutability: + def test_frozen(self) -> None: + mt = MeetingTypeConfig(name="standup", frequency="daily") + with pytest.raises(ValidationError): + mt.name = "new" # type: ignore[misc] + + def test_model_copy(self) -> None: + original = MeetingTypeConfig(name="standup", frequency="daily") + updated = original.model_copy(update={"duration_tokens": 3000}) + assert updated.duration_tokens == 3000 + assert original.duration_tokens == 2000 + + +@pytest.mark.unit +class TestMeetingTypeConfigSerialization: + def test_json_roundtrip(self) -> None: + mt = MeetingTypeConfig( + name="standup", + frequency="daily", + participants=("eng",), + duration_tokens=1500, + ) + restored = MeetingTypeConfig.model_validate_json(mt.model_dump_json()) + assert restored == mt + + def test_factory(self) -> None: + from tests.unit.communication.conftest import MeetingTypeConfigFactory + + mt = MeetingTypeConfigFactory.build() + assert isinstance(mt, MeetingTypeConfig) + + +# ── MeetingsConfig ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestMeetingsConfigConstruction: + def test_defaults(self) -> None: + cfg = MeetingsConfig() + assert cfg.enabled is True + assert cfg.types == () + + def test_custom_values(self) -> None: + mt = MeetingTypeConfig(name="standup", frequency="daily") + cfg = MeetingsConfig(enabled=False, types=(mt,)) + assert cfg.enabled is False + assert len(cfg.types) == 1 + + +@pytest.mark.unit +class TestMeetingsConfigValidation: + def test_duplicate_meeting_names_rejected(self) -> None: + mt1 = MeetingTypeConfig(name="standup", frequency="daily") + mt2 = MeetingTypeConfig(name="standup", trigger="on_pr") + with pytest.raises(ValidationError, match="Duplicate meeting type names"): + MeetingsConfig(types=(mt1, mt2)) + + def test_unique_meeting_names_accepted(self) -> None: + mt1 = MeetingTypeConfig(name="standup", frequency="daily") + mt2 = MeetingTypeConfig(name="review", trigger="on_pr") + cfg = MeetingsConfig(types=(mt1, mt2)) + assert len(cfg.types) == 2 + + +@pytest.mark.unit +class TestMeetingsConfigImmutability: + def test_frozen(self) -> None: + cfg = MeetingsConfig() + with pytest.raises(ValidationError): + cfg.enabled = False # type: ignore[misc] + + def test_json_roundtrip(self) -> None: + mt = MeetingTypeConfig(name="standup", frequency="daily") + cfg = MeetingsConfig(types=(mt,)) + restored = MeetingsConfig.model_validate_json(cfg.model_dump_json()) + assert restored == cfg + + def test_factory(self) -> None: + from tests.unit.communication.conftest import MeetingsConfigFactory + + cfg = MeetingsConfigFactory.build() + assert isinstance(cfg, MeetingsConfig) + + +# ── HierarchyConfig ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestHierarchyConfig: + def test_defaults(self) -> None: + cfg = HierarchyConfig() + assert cfg.enforce_chain_of_command is True + assert cfg.allow_skip_level is False + + def test_custom_values(self) -> None: + cfg = HierarchyConfig(enforce_chain_of_command=False, allow_skip_level=True) + assert cfg.enforce_chain_of_command is False + assert cfg.allow_skip_level is True + + def test_frozen(self) -> None: + cfg = HierarchyConfig() + with pytest.raises(ValidationError): + cfg.allow_skip_level = True # type: ignore[misc] + + def test_json_roundtrip(self) -> None: + cfg = HierarchyConfig(enforce_chain_of_command=False, allow_skip_level=True) + restored = HierarchyConfig.model_validate_json(cfg.model_dump_json()) + assert restored == cfg + + def test_model_copy(self) -> None: + original = HierarchyConfig() + updated = original.model_copy(update={"allow_skip_level": True}) + assert updated.allow_skip_level is True + assert original.allow_skip_level is False + + def test_factory(self) -> None: + from tests.unit.communication.conftest import HierarchyConfigFactory + + cfg = HierarchyConfigFactory.build() + assert isinstance(cfg, HierarchyConfig) + + +# ── RateLimitConfig ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestRateLimitConfig: + def test_defaults(self) -> None: + cfg = RateLimitConfig() + assert cfg.max_per_pair_per_minute == 10 + assert cfg.burst_allowance == 3 + + def test_custom_values(self) -> None: + cfg = RateLimitConfig(max_per_pair_per_minute=20, burst_allowance=5) + assert cfg.max_per_pair_per_minute == 20 + assert cfg.burst_allowance == 5 + + def test_zero_rate_rejected(self) -> None: + with pytest.raises(ValidationError): + RateLimitConfig(max_per_pair_per_minute=0) + + def test_negative_rate_rejected(self) -> None: + with pytest.raises(ValidationError): + RateLimitConfig(max_per_pair_per_minute=-1) + + def test_zero_burst_allowed(self) -> None: + cfg = RateLimitConfig(burst_allowance=0) + assert cfg.burst_allowance == 0 + + def test_negative_burst_rejected(self) -> None: + with pytest.raises(ValidationError): + RateLimitConfig(burst_allowance=-1) + + def test_frozen(self) -> None: + cfg = RateLimitConfig() + with pytest.raises(ValidationError): + cfg.max_per_pair_per_minute = 20 # type: ignore[misc] + + def test_json_roundtrip(self) -> None: + cfg = RateLimitConfig(max_per_pair_per_minute=15, burst_allowance=2) + restored = RateLimitConfig.model_validate_json(cfg.model_dump_json()) + assert restored == cfg + + def test_factory(self) -> None: + from tests.unit.communication.conftest import RateLimitConfigFactory + + cfg = RateLimitConfigFactory.build() + assert isinstance(cfg, RateLimitConfig) + + +# ── CircuitBreakerConfig ──────────────────────────────────────── + + +@pytest.mark.unit +class TestCircuitBreakerConfig: + def test_defaults(self) -> None: + cfg = CircuitBreakerConfig() + assert cfg.bounce_threshold == 3 + assert cfg.cooldown_seconds == 300 + + def test_custom_values(self) -> None: + cfg = CircuitBreakerConfig(bounce_threshold=5, cooldown_seconds=600) + assert cfg.bounce_threshold == 5 + assert cfg.cooldown_seconds == 600 + + def test_zero_threshold_rejected(self) -> None: + with pytest.raises(ValidationError): + CircuitBreakerConfig(bounce_threshold=0) + + def test_zero_cooldown_rejected(self) -> None: + with pytest.raises(ValidationError): + CircuitBreakerConfig(cooldown_seconds=0) + + def test_frozen(self) -> None: + cfg = CircuitBreakerConfig() + with pytest.raises(ValidationError): + cfg.bounce_threshold = 5 # type: ignore[misc] + + def test_json_roundtrip(self) -> None: + cfg = CircuitBreakerConfig(bounce_threshold=10, cooldown_seconds=120) + restored = CircuitBreakerConfig.model_validate_json(cfg.model_dump_json()) + assert restored == cfg + + def test_factory(self) -> None: + from tests.unit.communication.conftest import CircuitBreakerConfigFactory + + cfg = CircuitBreakerConfigFactory.build() + assert isinstance(cfg, CircuitBreakerConfig) + + +# ── LoopPreventionConfig ─────────────────────────────────────── + + +@pytest.mark.unit +class TestLoopPreventionConfigDefaults: + def test_defaults(self) -> None: + cfg = LoopPreventionConfig() + assert cfg.max_delegation_depth == 5 + assert isinstance(cfg.rate_limit, RateLimitConfig) + assert cfg.dedup_window_seconds == 60 + assert isinstance(cfg.circuit_breaker, CircuitBreakerConfig) + assert cfg.ancestry_tracking is True + + def test_custom_values(self) -> None: + cfg = LoopPreventionConfig( + max_delegation_depth=10, + rate_limit=RateLimitConfig(max_per_pair_per_minute=20), + dedup_window_seconds=120, + circuit_breaker=CircuitBreakerConfig(bounce_threshold=5), + ) + assert cfg.max_delegation_depth == 10 + assert cfg.rate_limit.max_per_pair_per_minute == 20 + assert cfg.dedup_window_seconds == 120 + assert cfg.circuit_breaker.bounce_threshold == 5 + + +@pytest.mark.unit +class TestLoopPreventionConfigValidation: + def test_ancestry_tracking_false_rejected(self) -> None: + with pytest.raises( + ValidationError, + match="ancestry_tracking must be True", + ): + LoopPreventionConfig(ancestry_tracking=False) + + def test_zero_delegation_depth_rejected(self) -> None: + with pytest.raises(ValidationError): + LoopPreventionConfig(max_delegation_depth=0) + + def test_zero_dedup_window_rejected(self) -> None: + with pytest.raises(ValidationError): + LoopPreventionConfig(dedup_window_seconds=0) + + +@pytest.mark.unit +class TestLoopPreventionConfigImmutability: + def test_frozen(self) -> None: + cfg = LoopPreventionConfig() + with pytest.raises(ValidationError): + cfg.max_delegation_depth = 10 # type: ignore[misc] + + def test_model_copy(self) -> None: + original = LoopPreventionConfig() + updated = original.model_copy(update={"max_delegation_depth": 10}) + assert updated.max_delegation_depth == 10 + assert original.max_delegation_depth == 5 + + +@pytest.mark.unit +class TestLoopPreventionConfigSerialization: + def test_json_roundtrip(self) -> None: + cfg = LoopPreventionConfig( + max_delegation_depth=8, + dedup_window_seconds=90, + ) + restored = LoopPreventionConfig.model_validate_json(cfg.model_dump_json()) + assert restored == cfg + + def test_factory(self) -> None: + from tests.unit.communication.conftest import LoopPreventionConfigFactory + + cfg = LoopPreventionConfigFactory.build() + assert isinstance(cfg, LoopPreventionConfig) + + +# ── CommunicationConfig ──────────────────────────────────────── + + +@pytest.mark.unit +class TestCommunicationConfigDefaults: + def test_defaults(self) -> None: + cfg = CommunicationConfig() + assert cfg.default_pattern is CommunicationPattern.HYBRID + assert isinstance(cfg.message_bus, MessageBusConfig) + assert isinstance(cfg.meetings, MeetingsConfig) + assert isinstance(cfg.hierarchy, HierarchyConfig) + assert isinstance(cfg.loop_prevention, LoopPreventionConfig) + + def test_custom_values(self) -> None: + cfg = CommunicationConfig( + default_pattern=CommunicationPattern.EVENT_DRIVEN, + message_bus=MessageBusConfig(backend=MessageBusBackend.REDIS), + hierarchy=HierarchyConfig(allow_skip_level=True), + ) + assert cfg.default_pattern is CommunicationPattern.EVENT_DRIVEN + assert cfg.message_bus.backend is MessageBusBackend.REDIS + assert cfg.hierarchy.allow_skip_level is True + + +@pytest.mark.unit +class TestCommunicationConfigImmutability: + def test_frozen(self) -> None: + cfg = CommunicationConfig() + with pytest.raises(ValidationError): + cfg.default_pattern = CommunicationPattern.HIERARCHICAL # type: ignore[misc] + + def test_model_copy(self) -> None: + original = CommunicationConfig() + updated = original.model_copy( + update={"default_pattern": CommunicationPattern.HIERARCHICAL} + ) + assert updated.default_pattern is CommunicationPattern.HIERARCHICAL + assert original.default_pattern is CommunicationPattern.HYBRID + + +@pytest.mark.unit +class TestCommunicationConfigSerialization: + def test_json_roundtrip(self) -> None: + cfg = CommunicationConfig( + default_pattern=CommunicationPattern.MEETING_BASED, + message_bus=MessageBusConfig(backend=MessageBusBackend.KAFKA), + ) + restored = CommunicationConfig.model_validate_json(cfg.model_dump_json()) + assert restored == cfg + + def test_model_dump_enum_values(self) -> None: + cfg = CommunicationConfig() + dumped = cfg.model_dump() + assert dumped["default_pattern"] == "hybrid" + assert dumped["message_bus"]["backend"] == "internal" + + def test_factory(self) -> None: + from tests.unit.communication.conftest import CommunicationConfigFactory + + cfg = CommunicationConfigFactory.build() + assert isinstance(cfg, CommunicationConfig) + + +@pytest.mark.unit +class TestCommunicationConfigFixtures: + def test_sample_communication_config( + self, sample_communication_config: CommunicationConfig + ) -> None: + expected = CommunicationPattern.HYBRID + assert sample_communication_config.default_pattern is expected + + def test_sample_meeting_type(self, sample_meeting_type: MeetingTypeConfig) -> None: + assert sample_meeting_type.name == "daily_standup" + assert sample_meeting_type.frequency == "per_sprint_day" diff --git a/tests/unit/communication/test_enums.py b/tests/unit/communication/test_enums.py new file mode 100644 index 0000000000..cd85c91ace --- /dev/null +++ b/tests/unit/communication/test_enums.py @@ -0,0 +1,97 @@ +"""Tests for the communication domain enumerations.""" + +import pytest + +from ai_company.communication.enums import ( + AttachmentType, + ChannelType, + CommunicationPattern, + MessageBusBackend, + MessagePriority, + MessageType, +) + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestMessageType: + def test_member_count(self) -> None: + assert len(MessageType) == 8 + + def test_values(self) -> None: + assert MessageType.TASK_UPDATE == "task_update" + assert MessageType.QUESTION == "question" + assert MessageType.ANNOUNCEMENT == "announcement" + assert MessageType.REVIEW_REQUEST == "review_request" + assert MessageType.APPROVAL == "approval" + assert MessageType.DELEGATION == "delegation" + assert MessageType.STATUS_REPORT == "status_report" + assert MessageType.ESCALATION == "escalation" + + def test_string_identity(self) -> None: + assert str(MessageType.TASK_UPDATE) == "task_update" + + +@pytest.mark.unit +class TestMessagePriority: + def test_member_count(self) -> None: + assert len(MessagePriority) == 4 + + def test_values(self) -> None: + assert MessagePriority.LOW == "low" + assert MessagePriority.NORMAL == "normal" + assert MessagePriority.HIGH == "high" + assert MessagePriority.URGENT == "urgent" + + def test_normal_not_medium(self) -> None: + """Message priority uses 'normal', not 'medium' like task Priority.""" + member_values = {m.value for m in MessagePriority} + assert "normal" in member_values + assert "medium" not in member_values + + +@pytest.mark.unit +class TestChannelType: + def test_member_count(self) -> None: + assert len(ChannelType) == 3 + + def test_values(self) -> None: + assert ChannelType.TOPIC == "topic" + assert ChannelType.DIRECT == "direct" + assert ChannelType.BROADCAST == "broadcast" + + +@pytest.mark.unit +class TestAttachmentType: + def test_member_count(self) -> None: + assert len(AttachmentType) == 3 + + def test_values(self) -> None: + assert AttachmentType.ARTIFACT == "artifact" + assert AttachmentType.FILE == "file" + assert AttachmentType.LINK == "link" + + +@pytest.mark.unit +class TestCommunicationPattern: + def test_member_count(self) -> None: + assert len(CommunicationPattern) == 4 + + def test_values(self) -> None: + assert CommunicationPattern.EVENT_DRIVEN == "event_driven" + assert CommunicationPattern.HIERARCHICAL == "hierarchical" + assert CommunicationPattern.MEETING_BASED == "meeting_based" + assert CommunicationPattern.HYBRID == "hybrid" + + +@pytest.mark.unit +class TestMessageBusBackend: + def test_member_count(self) -> None: + assert len(MessageBusBackend) == 4 + + def test_values(self) -> None: + assert MessageBusBackend.INTERNAL == "internal" + assert MessageBusBackend.REDIS == "redis" + assert MessageBusBackend.RABBITMQ == "rabbitmq" + assert MessageBusBackend.KAFKA == "kafka" diff --git a/tests/unit/communication/test_message.py b/tests/unit/communication/test_message.py new file mode 100644 index 0000000000..5d14533bc9 --- /dev/null +++ b/tests/unit/communication/test_message.py @@ -0,0 +1,366 @@ +"""Tests for the Attachment, MessageMetadata, and Message domain models.""" + +from datetime import UTC, datetime +from uuid import UUID + +import pytest +from pydantic import ValidationError + +from ai_company.communication.enums import ( + AttachmentType, + MessagePriority, + MessageType, +) +from ai_company.communication.message import Attachment, Message, MessageMetadata + +pytestmark = pytest.mark.timeout(30) + +# ── Helpers ────────────────────────────────────────────────────── + +_MESSAGE_KWARGS: dict[str, object] = { + "timestamp": datetime(2026, 2, 27, 10, 30, tzinfo=UTC), + "sender": "sarah_chen", + "to": "engineering", + "type": MessageType.TASK_UPDATE, + "channel": "#backend", + "content": "PR ready for review.", +} + + +def _make_message(**overrides: object) -> Message: + """Create a Message with sensible defaults, applying overrides.""" + kwargs = {**_MESSAGE_KWARGS, **overrides} + return Message(**kwargs) # type: ignore[arg-type] + + +# ── Attachment ────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestAttachment: + def test_construction(self) -> None: + att = Attachment(type=AttachmentType.ARTIFACT, ref="pr-42") + assert att.type is AttachmentType.ARTIFACT + assert att.ref == "pr-42" + + def test_empty_ref_rejected(self) -> None: + with pytest.raises(ValidationError): + Attachment(type=AttachmentType.FILE, ref="") + + def test_whitespace_ref_rejected(self) -> None: + with pytest.raises(ValidationError, match="ref must not be whitespace-only"): + Attachment(type=AttachmentType.FILE, ref=" ") + + def test_frozen(self) -> None: + att = Attachment(type=AttachmentType.LINK, ref="https://example.com") + with pytest.raises(ValidationError): + att.ref = "other" # type: ignore[misc] + + def test_json_roundtrip(self) -> None: + att = Attachment(type=AttachmentType.ARTIFACT, ref="pr-42") + restored = Attachment.model_validate_json(att.model_dump_json()) + assert restored == att + + def test_model_dump(self) -> None: + att = Attachment(type=AttachmentType.ARTIFACT, ref="pr-42") + dumped = att.model_dump() + assert dumped["type"] == "artifact" + assert dumped["ref"] == "pr-42" + + def test_factory(self) -> None: + from tests.unit.communication.conftest import AttachmentFactory + + att = AttachmentFactory.build() + assert isinstance(att, Attachment) + + def test_model_copy(self) -> None: + original = Attachment(type=AttachmentType.FILE, ref="readme.md") + updated = original.model_copy(update={"ref": "changelog.md"}) + assert updated.ref == "changelog.md" + assert original.ref == "readme.md" + + +# ── MessageMetadata ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestMessageMetadataDefaults: + def test_defaults(self) -> None: + meta = MessageMetadata() + assert meta.task_id is None + assert meta.project_id is None + assert meta.tokens_used is None + assert meta.cost_usd is None + assert meta.extra == () + + def test_custom_values(self) -> None: + meta = MessageMetadata( + task_id="task-1", + project_id="proj-1", + tokens_used=500, + cost_usd=0.05, + extra=(("key1", "val1"),), + ) + assert meta.task_id == "task-1" + assert meta.project_id == "proj-1" + assert meta.tokens_used == 500 + assert meta.cost_usd == 0.05 + assert meta.extra == (("key1", "val1"),) + + +@pytest.mark.unit +class TestMessageMetadataValidation: + def test_whitespace_task_id_rejected(self) -> None: + with pytest.raises( + ValidationError, match="task_id must not be whitespace-only" + ): + MessageMetadata(task_id=" ") + + def test_whitespace_project_id_rejected(self) -> None: + with pytest.raises( + ValidationError, match="project_id must not be whitespace-only" + ): + MessageMetadata(project_id=" ") + + def test_negative_tokens_rejected(self) -> None: + with pytest.raises(ValidationError): + MessageMetadata(tokens_used=-1) + + def test_negative_cost_rejected(self) -> None: + with pytest.raises(ValidationError): + MessageMetadata(cost_usd=-0.01) + + def test_zero_tokens_allowed(self) -> None: + meta = MessageMetadata(tokens_used=0) + assert meta.tokens_used == 0 + + def test_zero_cost_allowed(self) -> None: + meta = MessageMetadata(cost_usd=0.0) + assert meta.cost_usd == 0.0 + + def test_blank_extra_key_rejected(self) -> None: + with pytest.raises(ValidationError, match="extra keys must not be blank"): + MessageMetadata(extra=((" ", "val"),)) + + def test_duplicate_extra_keys_rejected(self) -> None: + with pytest.raises(ValidationError, match="Duplicate keys in extra"): + MessageMetadata(extra=(("k", "v1"), ("k", "v2"))) + + +@pytest.mark.unit +class TestMessageMetadataImmutability: + def test_frozen(self) -> None: + meta = MessageMetadata(task_id="task-1") + with pytest.raises(ValidationError): + meta.task_id = "task-2" # type: ignore[misc] + + def test_model_copy(self) -> None: + original = MessageMetadata(task_id="task-1") + updated = original.model_copy(update={"task_id": "task-2"}) + assert updated.task_id == "task-2" + assert original.task_id == "task-1" + + +@pytest.mark.unit +class TestMessageMetadataSerialization: + def test_json_roundtrip(self) -> None: + meta = MessageMetadata( + task_id="task-1", + project_id="proj-1", + tokens_used=100, + cost_usd=0.01, + extra=(("env", "prod"),), + ) + restored = MessageMetadata.model_validate_json(meta.model_dump_json()) + assert restored == meta + + def test_factory(self) -> None: + from tests.unit.communication.conftest import MessageMetadataFactory + + meta = MessageMetadataFactory.build() + assert isinstance(meta, MessageMetadata) + + +# ── Message ───────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestMessageConstruction: + def test_minimal_valid(self) -> None: + msg = _make_message() + assert isinstance(msg.id, UUID) + assert msg.sender == "sarah_chen" + assert msg.to == "engineering" + assert msg.type is MessageType.TASK_UPDATE + assert msg.channel == "#backend" + assert msg.content == "PR ready for review." + + def test_default_values(self) -> None: + msg = _make_message() + assert msg.priority is MessagePriority.NORMAL + assert msg.attachments == () + assert isinstance(msg.metadata, MessageMetadata) + + def test_all_fields_set(self) -> None: + meta = MessageMetadata(task_id="task-1") + att = Attachment(type=AttachmentType.ARTIFACT, ref="pr-42") + msg = Message( + timestamp=datetime(2026, 2, 27, 10, 30, tzinfo=UTC), + sender="sarah_chen", + to="engineering", + type=MessageType.REVIEW_REQUEST, + priority=MessagePriority.HIGH, + channel="#code-review", + content="Please review PR-42.", + attachments=(att,), + metadata=meta, + ) + assert msg.priority is MessagePriority.HIGH + assert len(msg.attachments) == 1 + assert msg.metadata.task_id == "task-1" + + +@pytest.mark.unit +class TestMessageAlias: + def test_alias_from_parsing(self) -> None: + """Parse JSON with 'from' key (DESIGN_SPEC 5.3 format).""" + data = { + "timestamp": "2026-02-27T10:30:00Z", + "from": "sarah_chen", + "to": "engineering", + "type": "task_update", + "channel": "#backend", + "content": "Hello.", + } + msg = Message.model_validate(data) + assert msg.sender == "sarah_chen" + + def test_populate_by_name(self) -> None: + """Parse JSON with 'sender' key (populate_by_name=True).""" + data = { + "timestamp": "2026-02-27T10:30:00Z", + "sender": "sarah_chen", + "to": "engineering", + "type": "task_update", + "channel": "#backend", + "content": "Hello.", + } + msg = Message.model_validate(data) + assert msg.sender == "sarah_chen" + + def test_dump_by_alias(self) -> None: + """model_dump(by_alias=True) outputs 'from' key.""" + msg = _make_message() + dumped = msg.model_dump(by_alias=True) + assert "from" in dumped + assert dumped["from"] == "sarah_chen" + + def test_dump_by_name(self) -> None: + """model_dump() outputs 'sender' key.""" + msg = _make_message() + dumped = msg.model_dump() + assert "sender" in dumped + assert dumped["sender"] == "sarah_chen" + + +@pytest.mark.unit +class TestMessageStringValidation: + def test_empty_sender_rejected(self) -> None: + with pytest.raises(ValidationError): + _make_message(sender="") + + def test_whitespace_sender_rejected(self) -> None: + with pytest.raises(ValidationError, match="sender must not be whitespace-only"): + _make_message(sender=" ") + + def test_empty_to_rejected(self) -> None: + with pytest.raises(ValidationError): + _make_message(to="") + + def test_whitespace_to_rejected(self) -> None: + with pytest.raises(ValidationError, match="to must not be whitespace-only"): + _make_message(to=" ") + + def test_empty_channel_rejected(self) -> None: + with pytest.raises(ValidationError): + _make_message(channel="") + + def test_whitespace_channel_rejected(self) -> None: + with pytest.raises( + ValidationError, match="channel must not be whitespace-only" + ): + _make_message(channel=" ") + + def test_empty_content_rejected(self) -> None: + with pytest.raises(ValidationError): + _make_message(content="") + + def test_whitespace_content_rejected(self) -> None: + with pytest.raises( + ValidationError, match="content must not be whitespace-only" + ): + _make_message(content=" ") + + +@pytest.mark.unit +class TestMessageImmutability: + def test_frozen(self) -> None: + msg = _make_message() + with pytest.raises(ValidationError): + msg.content = "new" # type: ignore[misc] + + def test_model_copy(self) -> None: + original = _make_message() + updated = original.model_copy(update={"content": "Updated content."}) + assert updated.content == "Updated content." + assert original.content == "PR ready for review." + + +@pytest.mark.unit +class TestMessageSerialization: + def test_json_roundtrip(self) -> None: + msg = _make_message( + attachments=(Attachment(type=AttachmentType.ARTIFACT, ref="pr-42"),), + metadata=MessageMetadata(task_id="task-1"), + ) + json_str = msg.model_dump_json() + restored = Message.model_validate_json(json_str) + assert restored.id == msg.id + assert restored.sender == msg.sender + assert restored.type is msg.type + assert restored.attachments == msg.attachments + assert restored.metadata == msg.metadata + + def test_model_dump_enum_values(self) -> None: + msg = _make_message() + dumped = msg.model_dump() + assert dumped["type"] == "task_update" + assert dumped["priority"] == "normal" + + +@pytest.mark.unit +class TestMessageFactory: + def test_factory(self) -> None: + from tests.unit.communication.conftest import MessageFactory + + msg = MessageFactory.build() + assert isinstance(msg, Message) + assert isinstance(msg.id, UUID) + assert isinstance(msg.type, MessageType) + + +@pytest.mark.unit +class TestMessageFixtures: + def test_sample_message(self, sample_message: Message) -> None: + assert sample_message.sender == "sarah_chen" + assert sample_message.to == "engineering" + assert sample_message.type is MessageType.TASK_UPDATE + assert len(sample_message.attachments) == 1 + + def test_sample_attachment(self, sample_attachment: Attachment) -> None: + assert sample_attachment.type is AttachmentType.ARTIFACT + assert sample_attachment.ref == "pr-42" + + def test_sample_metadata(self, sample_metadata: MessageMetadata) -> None: + assert sample_metadata.task_id == "task-123" + assert sample_metadata.tokens_used == 1200 From 95d5bfc7e095f1326a458ea0bd0543b7bffbf246 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sat, 28 Feb 2026 14:06:13 +0100 Subject: [PATCH 2/3] fix: address 18 PR review items from local agents and external reviewers - Add MeetingTypeConfig validators: whitespace-only name/frequency/trigger, blank/duplicate participants (Critical: 3 items) - Replace ancestry_tracking bool+validator with Literal[True] (Major) - Use AwareDatetime for Message.timestamp to enforce timezone (Major) - Extract NotBlankStr annotated type to core/types.py, eliminating 5 repeated model validators across message.py, channel.py, config.py (Major: DRY) - Fix MessageMetadata/Message docstrings for spec accuracy (Major) - Expand ChannelType and AttachmentType docstrings (Medium) - Fix noqa comment consistency on imports (Medium) - Add 13 new tests: __all__ exports, JSON alias roundtrip, empty string boundaries, unique message IDs, MeetingTypeConfig validation (Medium) - Replace brittle len(channels)==7 with _DEFAULT_CHANNELS comparison (Minor) --- src/ai_company/communication/channel.py | 13 ++--- src/ai_company/communication/config.py | 38 ++++++++------ src/ai_company/communication/enums.py | 16 +++++- src/ai_company/communication/message.py | 67 ++++++++---------------- src/ai_company/core/__init__.py | 2 + src/ai_company/core/types.py | 21 ++++++++ tests/unit/communication/test_channel.py | 8 ++- tests/unit/communication/test_config.py | 52 ++++++++++++++++-- tests/unit/communication/test_enums.py | 9 ++++ tests/unit/communication/test_message.py | 49 +++++++++++------ 10 files changed, 182 insertions(+), 93 deletions(-) create mode 100644 src/ai_company/core/types.py diff --git a/src/ai_company/communication/channel.py b/src/ai_company/communication/channel.py index 24e8846619..882a62becb 100644 --- a/src/ai_company/communication/channel.py +++ b/src/ai_company/communication/channel.py @@ -6,6 +6,9 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator from ai_company.communication.enums import ChannelType +from ai_company.core.types import ( + NotBlankStr, # noqa: TC001 -- required at runtime by Pydantic +) class Channel(BaseModel): @@ -19,7 +22,7 @@ class Channel(BaseModel): model_config = ConfigDict(frozen=True) - name: str = Field(min_length=1, description="Channel name") + name: NotBlankStr = Field(description="Channel name") type: ChannelType = Field( default=ChannelType.TOPIC, description="Channel delivery semantics", @@ -29,14 +32,6 @@ class Channel(BaseModel): description="Agent IDs subscribed to this channel", ) - @model_validator(mode="after") - def _validate_name_not_blank(self) -> Self: - """Ensure name is not whitespace-only.""" - if not self.name.strip(): - msg = "name must not be whitespace-only" - raise ValueError(msg) - return self - @model_validator(mode="after") def _validate_subscribers(self) -> Self: """Ensure subscriber entries are non-blank and unique.""" diff --git a/src/ai_company/communication/config.py b/src/ai_company/communication/config.py index 296a45f962..d00317041f 100644 --- a/src/ai_company/communication/config.py +++ b/src/ai_company/communication/config.py @@ -1,7 +1,7 @@ """Communication configuration models (DESIGN_SPEC Sections 5.4, 5.5).""" from collections import Counter -from typing import Self +from typing import Literal, Self from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -9,6 +9,9 @@ CommunicationPattern, MessageBusBackend, ) +from ai_company.core.types import ( + NotBlankStr, # noqa: TC001 -- required at runtime by Pydantic +) # Default channels from DESIGN_SPEC Section 5.4. _DEFAULT_CHANNELS: tuple[str, ...] = ( @@ -73,15 +76,13 @@ class MeetingTypeConfig(BaseModel): model_config = ConfigDict(frozen=True) - name: str = Field(min_length=1, description="Meeting type name") - frequency: str | None = Field( + name: NotBlankStr = Field(description="Meeting type name") + frequency: NotBlankStr | None = Field( default=None, - min_length=1, description="Recurrence schedule", ) - trigger: str | None = Field( + trigger: NotBlankStr | None = Field( default=None, - min_length=1, description="Event trigger", ) participants: tuple[str, ...] = Field( @@ -105,6 +106,19 @@ def _validate_frequency_or_trigger(self) -> Self: raise ValueError(msg) return self + @model_validator(mode="after") + def _validate_participants(self) -> Self: + """Ensure participant entries are non-blank and unique.""" + for p in self.participants: + if not p.strip(): + msg = "Empty or whitespace-only entry in participants" + raise ValueError(msg) + if len(self.participants) != len(set(self.participants)): + dupes = sorted(p for p, c in Counter(self.participants).items() if c > 1) + msg = f"Duplicate entries in participants: {dupes}" + raise ValueError(msg) + return self + class MeetingsConfig(BaseModel): """Meetings subsystem configuration. @@ -239,19 +253,11 @@ class LoopPreventionConfig(BaseModel): default_factory=CircuitBreakerConfig, description="Circuit breaker settings", ) - ancestry_tracking: bool = Field( + ancestry_tracking: Literal[True] = Field( default=True, - description="Task ancestry tracking (always on)", + description="Task ancestry tracking (always on, not configurable)", ) - @model_validator(mode="after") - def _validate_ancestry_tracking(self) -> Self: - """Ancestry tracking must always be enabled (spec: not configurable).""" - if not self.ancestry_tracking: - msg = "ancestry_tracking must be True (always on, not configurable)" - raise ValueError(msg) - return self - class CommunicationConfig(BaseModel): """Top-level communication configuration. diff --git a/src/ai_company/communication/enums.py b/src/ai_company/communication/enums.py index f6f3fee6b3..2bc1eac9bc 100644 --- a/src/ai_company/communication/enums.py +++ b/src/ai_company/communication/enums.py @@ -33,7 +33,13 @@ class MessagePriority(StrEnum): class ChannelType(StrEnum): - """Channel delivery semantics.""" + """Channel delivery semantics. + + Members: + TOPIC: Publish-subscribe delivery to all subscribers. + DIRECT: Point-to-point delivery to a single recipient. + BROADCAST: Delivery to all agents regardless of subscription. + """ TOPIC = "topic" DIRECT = "direct" @@ -41,7 +47,13 @@ class ChannelType(StrEnum): class AttachmentType(StrEnum): - """Type of message attachment.""" + """Type of message attachment. + + Members: + ARTIFACT: Reference to a domain artifact (e.g. PR, build output). + FILE: Reference to a file path. + LINK: Reference to a URL. + """ ARTIFACT = "artifact" FILE = "file" diff --git a/src/ai_company/communication/message.py b/src/ai_company/communication/message.py index 2b67b514df..2a739dd043 100644 --- a/src/ai_company/communication/message.py +++ b/src/ai_company/communication/message.py @@ -1,17 +1,19 @@ """Message domain models (DESIGN_SPEC Section 5.3).""" from collections import Counter -from datetime import datetime # noqa: TC003 from typing import Self from uuid import UUID, uuid4 -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator from ai_company.communication.enums import ( AttachmentType, MessagePriority, MessageType, ) +from ai_company.core.types import ( + NotBlankStr, # noqa: TC001 -- required at runtime by Pydantic +) class Attachment(BaseModel): @@ -25,38 +27,31 @@ class Attachment(BaseModel): model_config = ConfigDict(frozen=True) type: AttachmentType = Field(description="Kind of attachment") - ref: str = Field(min_length=1, description="Reference identifier") - - @model_validator(mode="after") - def _validate_ref_not_blank(self) -> Self: - """Ensure ref is not whitespace-only.""" - if not self.ref.strip(): - msg = "ref must not be whitespace-only" - raise ValueError(msg) - return self + ref: NotBlankStr = Field(description="Reference identifier") class MessageMetadata(BaseModel): """Optional metadata carried with a message. + Extends DESIGN_SPEC Section 5.3 metadata with an additional ``extra`` + field for arbitrary key-value pairs. + Attributes: task_id: Related task identifier. project_id: Related project identifier. tokens_used: LLM tokens consumed producing the message. cost_usd: Estimated USD cost of the message. - extra: Immutable key-value pairs for arbitrary metadata. + extra: Immutable key-value pairs for arbitrary metadata (extension). """ model_config = ConfigDict(frozen=True) - task_id: str | None = Field( + task_id: NotBlankStr | None = Field( default=None, - min_length=1, description="Related task identifier", ) - project_id: str | None = Field( + project_id: NotBlankStr | None = Field( default=None, - min_length=1, description="Related project identifier", ) tokens_used: int | None = Field( @@ -74,16 +69,6 @@ class MessageMetadata(BaseModel): description="Immutable key-value pairs for arbitrary metadata", ) - @model_validator(mode="after") - def _validate_optional_strings(self) -> Self: - """Ensure optional string fields are not whitespace-only.""" - for field_name in ("task_id", "project_id"): - value = getattr(self, field_name) - if value is not None and not value.strip(): - msg = f"{field_name} must not be whitespace-only" - raise ValueError(msg) - return self - @model_validator(mode="after") def _validate_extra(self) -> Self: """Ensure extra keys are non-blank and unique.""" @@ -103,12 +88,13 @@ def _validate_extra(self) -> Self: class Message(BaseModel): """An inter-agent message. - Field schema matches DESIGN_SPEC Section 5.3. The ``sender`` field - is aliased to ``"from"`` for JSON compatibility with the spec format. + Field schema is based on DESIGN_SPEC Section 5.3 with typed refinements. + The ``sender`` field is aliased to ``"from"`` for JSON compatibility with + the spec format. Attributes: id: Unique message identifier. - timestamp: When the message was created. + timestamp: When the message was created (must be timezone-aware). sender: Agent ID of the sender (aliased to ``"from"`` in JSON). to: Recipient agent or channel identifier. type: Message type classification. @@ -125,23 +111,23 @@ class Message(BaseModel): default_factory=uuid4, description="Unique message identifier", ) - timestamp: datetime = Field(description="When the message was created") - sender: str = Field( - min_length=1, + timestamp: AwareDatetime = Field( + description="When the message was created (must be timezone-aware)", + ) + sender: NotBlankStr = Field( alias="from", description="Sender agent ID", ) - to: str = Field(min_length=1, description="Recipient agent or channel") + to: NotBlankStr = Field(description="Recipient agent or channel") type: MessageType = Field(description="Message type classification") priority: MessagePriority = Field( default=MessagePriority.NORMAL, description="Message priority level", ) - channel: str = Field( - min_length=1, + channel: NotBlankStr = Field( description="Channel the message is sent through", ) - content: str = Field(min_length=1, description="Message body text") + content: NotBlankStr = Field(description="Message body text") attachments: tuple[Attachment, ...] = Field( default=(), description="Attached references", @@ -150,12 +136,3 @@ class Message(BaseModel): default_factory=MessageMetadata, description="Optional message metadata", ) - - @model_validator(mode="after") - def _validate_strings_not_blank(self) -> Self: - """Ensure required string fields are not whitespace-only.""" - for field_name in ("sender", "to", "channel", "content"): - if not getattr(self, field_name).strip(): - msg = f"{field_name} must not be whitespace-only" - raise ValueError(msg) - return self diff --git a/src/ai_company/core/__init__.py b/src/ai_company/core/__init__.py index cb54848ba9..ab2815fece 100644 --- a/src/ai_company/core/__init__.py +++ b/src/ai_company/core/__init__.py @@ -50,6 +50,7 @@ ) from ai_company.core.task import AcceptanceCriterion, Task from ai_company.core.task_transitions import VALID_TRANSITIONS, validate_transition +from ai_company.core.types import NotBlankStr __all__ = [ "BUILTIN_ROLES", @@ -75,6 +76,7 @@ "MemoryConfig", "MemoryType", "ModelConfig", + "NotBlankStr", "PersonalityConfig", "Priority", "ProficiencyLevel", diff --git a/src/ai_company/core/types.py b/src/ai_company/core/types.py new file mode 100644 index 0000000000..7abff16c00 --- /dev/null +++ b/src/ai_company/core/types.py @@ -0,0 +1,21 @@ +"""Reusable Pydantic type annotations.""" + +from typing import Annotated + +from pydantic import AfterValidator, StringConstraints + + +def _check_not_whitespace(value: str) -> str: + """Reject whitespace-only strings.""" + if not value.strip(): + msg = "must not be whitespace-only" + raise ValueError(msg) + return value + + +NotBlankStr = Annotated[ + str, + StringConstraints(min_length=1), + AfterValidator(_check_not_whitespace), +] +"""A string that must be non-empty and not consist solely of whitespace.""" diff --git a/tests/unit/communication/test_channel.py b/tests/unit/communication/test_channel.py index df17111652..3915503d23 100644 --- a/tests/unit/communication/test_channel.py +++ b/tests/unit/communication/test_channel.py @@ -41,9 +41,15 @@ def test_empty_name_rejected(self) -> None: Channel(name="") def test_whitespace_name_rejected(self) -> None: - with pytest.raises(ValidationError, match="name must not be whitespace-only"): + with pytest.raises(ValidationError, match="whitespace-only"): Channel(name=" ") + def test_empty_subscriber_rejected(self) -> None: + with pytest.raises( + ValidationError, match="Empty or whitespace-only entry in subscribers" + ): + Channel(name="#test", subscribers=("agent-a", "")) + def test_whitespace_subscriber_rejected(self) -> None: with pytest.raises( ValidationError, match="Empty or whitespace-only entry in subscribers" diff --git a/tests/unit/communication/test_config.py b/tests/unit/communication/test_config.py index e7f82acabd..06d19b188b 100644 --- a/tests/unit/communication/test_config.py +++ b/tests/unit/communication/test_config.py @@ -4,6 +4,7 @@ from pydantic import ValidationError from ai_company.communication.config import ( + _DEFAULT_CHANNELS, CircuitBreakerConfig, CommunicationConfig, HierarchyConfig, @@ -29,10 +30,7 @@ class TestMessageBusConfigDefaults: def test_defaults(self) -> None: cfg = MessageBusConfig() assert cfg.backend is MessageBusBackend.INTERNAL - assert len(cfg.channels) == 7 - assert "#all-hands" in cfg.channels - assert "#engineering" in cfg.channels - assert "#watercooler" in cfg.channels + assert cfg.channels == _DEFAULT_CHANNELS def test_custom_values(self) -> None: cfg = MessageBusConfig( @@ -45,6 +43,12 @@ def test_custom_values(self) -> None: @pytest.mark.unit class TestMessageBusConfigValidation: + def test_empty_channel_rejected(self) -> None: + with pytest.raises( + ValidationError, match="Empty or whitespace-only entry in channels" + ): + MessageBusConfig(channels=("#valid", "")) + def test_whitespace_channel_rejected(self) -> None: with pytest.raises( ValidationError, match="Empty or whitespace-only entry in channels" @@ -140,6 +144,44 @@ def test_empty_name_rejected(self) -> None: with pytest.raises(ValidationError): MeetingTypeConfig(name="", frequency="daily") + def test_whitespace_name_rejected(self) -> None: + with pytest.raises(ValidationError, match="whitespace-only"): + MeetingTypeConfig(name=" ", frequency="daily") + + def test_whitespace_frequency_rejected(self) -> None: + with pytest.raises(ValidationError, match="whitespace-only"): + MeetingTypeConfig(name="standup", frequency=" ") + + def test_whitespace_trigger_rejected(self) -> None: + with pytest.raises(ValidationError, match="whitespace-only"): + MeetingTypeConfig(name="review", trigger=" ") + + def test_whitespace_participant_rejected(self) -> None: + with pytest.raises( + ValidationError, + match="Empty or whitespace-only entry in participants", + ): + MeetingTypeConfig( + name="standup", frequency="daily", participants=("eng", " ") + ) + + def test_empty_participant_rejected(self) -> None: + with pytest.raises( + ValidationError, + match="Empty or whitespace-only entry in participants", + ): + MeetingTypeConfig( + name="standup", frequency="daily", participants=("eng", "") + ) + + def test_duplicate_participants_rejected(self) -> None: + with pytest.raises(ValidationError, match="Duplicate entries in participants"): + MeetingTypeConfig( + name="standup", + frequency="daily", + participants=("eng", "qa", "eng"), + ) + def test_zero_duration_rejected(self) -> None: with pytest.raises(ValidationError): MeetingTypeConfig(name="bad", frequency="daily", duration_tokens=0) @@ -391,7 +433,7 @@ class TestLoopPreventionConfigValidation: def test_ancestry_tracking_false_rejected(self) -> None: with pytest.raises( ValidationError, - match="ancestry_tracking must be True", + match="Input should be True", ): LoopPreventionConfig(ancestry_tracking=False) diff --git a/tests/unit/communication/test_enums.py b/tests/unit/communication/test_enums.py index cd85c91ace..e68778b925 100644 --- a/tests/unit/communication/test_enums.py +++ b/tests/unit/communication/test_enums.py @@ -85,6 +85,15 @@ def test_values(self) -> None: assert CommunicationPattern.HYBRID == "hybrid" +@pytest.mark.unit +class TestCommunicationExports: + def test_all_exports_importable(self) -> None: + import ai_company.communication as comm_module + + for name in comm_module.__all__: + assert hasattr(comm_module, name), f"{name} in __all__ but not importable" + + @pytest.mark.unit class TestMessageBusBackend: def test_member_count(self) -> None: diff --git a/tests/unit/communication/test_message.py b/tests/unit/communication/test_message.py index 5d14533bc9..4611ac57d2 100644 --- a/tests/unit/communication/test_message.py +++ b/tests/unit/communication/test_message.py @@ -48,7 +48,7 @@ def test_empty_ref_rejected(self) -> None: Attachment(type=AttachmentType.FILE, ref="") def test_whitespace_ref_rejected(self) -> None: - with pytest.raises(ValidationError, match="ref must not be whitespace-only"): + with pytest.raises(ValidationError, match="whitespace-only"): Attachment(type=AttachmentType.FILE, ref=" ") def test_frozen(self) -> None: @@ -110,16 +110,20 @@ def test_custom_values(self) -> None: @pytest.mark.unit class TestMessageMetadataValidation: + def test_empty_task_id_rejected(self) -> None: + with pytest.raises(ValidationError): + MessageMetadata(task_id="") + def test_whitespace_task_id_rejected(self) -> None: - with pytest.raises( - ValidationError, match="task_id must not be whitespace-only" - ): + with pytest.raises(ValidationError, match="whitespace-only"): MessageMetadata(task_id=" ") + def test_empty_project_id_rejected(self) -> None: + with pytest.raises(ValidationError): + MessageMetadata(project_id="") + def test_whitespace_project_id_rejected(self) -> None: - with pytest.raises( - ValidationError, match="project_id must not be whitespace-only" - ): + with pytest.raises(ValidationError, match="whitespace-only"): MessageMetadata(project_id=" ") def test_negative_tokens_rejected(self) -> None: @@ -220,6 +224,14 @@ def test_all_fields_set(self) -> None: assert msg.metadata.task_id == "task-1" +@pytest.mark.unit +class TestMessageUniqueIds: + def test_unique_ids(self) -> None: + msg1 = _make_message() + msg2 = _make_message() + assert msg1.id != msg2.id + + @pytest.mark.unit class TestMessageAlias: def test_alias_from_parsing(self) -> None: @@ -263,6 +275,17 @@ def test_dump_by_name(self) -> None: assert dumped["sender"] == "sarah_chen" +@pytest.mark.unit +class TestMessageAliasRoundtrip: + def test_json_roundtrip_with_alias(self) -> None: + """Ensure JSON with 'from' key (DESIGN_SPEC 5.3 format) round-trips.""" + msg = _make_message() + json_str = msg.model_dump_json(by_alias=True) + assert '"from"' in json_str + restored = Message.model_validate_json(json_str) + assert restored == msg + + @pytest.mark.unit class TestMessageStringValidation: def test_empty_sender_rejected(self) -> None: @@ -270,7 +293,7 @@ def test_empty_sender_rejected(self) -> None: _make_message(sender="") def test_whitespace_sender_rejected(self) -> None: - with pytest.raises(ValidationError, match="sender must not be whitespace-only"): + with pytest.raises(ValidationError, match="whitespace-only"): _make_message(sender=" ") def test_empty_to_rejected(self) -> None: @@ -278,7 +301,7 @@ def test_empty_to_rejected(self) -> None: _make_message(to="") def test_whitespace_to_rejected(self) -> None: - with pytest.raises(ValidationError, match="to must not be whitespace-only"): + with pytest.raises(ValidationError, match="whitespace-only"): _make_message(to=" ") def test_empty_channel_rejected(self) -> None: @@ -286,9 +309,7 @@ def test_empty_channel_rejected(self) -> None: _make_message(channel="") def test_whitespace_channel_rejected(self) -> None: - with pytest.raises( - ValidationError, match="channel must not be whitespace-only" - ): + with pytest.raises(ValidationError, match="whitespace-only"): _make_message(channel=" ") def test_empty_content_rejected(self) -> None: @@ -296,9 +317,7 @@ def test_empty_content_rejected(self) -> None: _make_message(content="") def test_whitespace_content_rejected(self) -> None: - with pytest.raises( - ValidationError, match="content must not be whitespace-only" - ): + with pytest.raises(ValidationError, match="whitespace-only"): _make_message(content=" ") From 3d193d9a62b5e344b3c30cd2a34ec736bc24a457 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sat, 28 Feb 2026 14:40:36 +0100 Subject: [PATCH 3/3] refactor: extract shared validate_non_blank_unique_strings helper DRY up repeated non-blank/unique tuple-of-strings validation across MessageBusConfig, MeetingTypeConfig, and Channel into a single reusable helper in core.types. Co-Authored-By: Claude Opus 4.6 --- src/ai_company/communication/channel.py | 13 +++---------- src/ai_company/communication/config.py | 21 ++++----------------- src/ai_company/core/__init__.py | 3 ++- src/ai_company/core/types.py | 23 ++++++++++++++++++++++- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/src/ai_company/communication/channel.py b/src/ai_company/communication/channel.py index 882a62becb..d5ae0dcc0a 100644 --- a/src/ai_company/communication/channel.py +++ b/src/ai_company/communication/channel.py @@ -1,13 +1,13 @@ """Channel domain model.""" -from collections import Counter from typing import Self from pydantic import BaseModel, ConfigDict, Field, model_validator from ai_company.communication.enums import ChannelType from ai_company.core.types import ( - NotBlankStr, # noqa: TC001 -- required at runtime by Pydantic + NotBlankStr, + validate_non_blank_unique_strings, ) @@ -35,12 +35,5 @@ class Channel(BaseModel): @model_validator(mode="after") def _validate_subscribers(self) -> Self: """Ensure subscriber entries are non-blank and unique.""" - for sub in self.subscribers: - if not sub.strip(): - msg = "Empty or whitespace-only entry in subscribers" - raise ValueError(msg) - if len(self.subscribers) != len(set(self.subscribers)): - dupes = sorted(s for s, c in Counter(self.subscribers).items() if c > 1) - msg = f"Duplicate entries in subscribers: {dupes}" - raise ValueError(msg) + validate_non_blank_unique_strings(self.subscribers, "subscribers") return self diff --git a/src/ai_company/communication/config.py b/src/ai_company/communication/config.py index d00317041f..05c9885bb0 100644 --- a/src/ai_company/communication/config.py +++ b/src/ai_company/communication/config.py @@ -10,7 +10,8 @@ MessageBusBackend, ) from ai_company.core.types import ( - NotBlankStr, # noqa: TC001 -- required at runtime by Pydantic + NotBlankStr, + validate_non_blank_unique_strings, ) # Default channels from DESIGN_SPEC Section 5.4. @@ -49,14 +50,7 @@ class MessageBusConfig(BaseModel): @model_validator(mode="after") def _validate_channels(self) -> Self: """Ensure channel names are non-blank and unique.""" - for ch in self.channels: - if not ch.strip(): - msg = "Empty or whitespace-only entry in channels" - raise ValueError(msg) - if len(self.channels) != len(set(self.channels)): - dupes = sorted(c for c, n in Counter(self.channels).items() if n > 1) - msg = f"Duplicate entries in channels: {dupes}" - raise ValueError(msg) + validate_non_blank_unique_strings(self.channels, "channels") return self @@ -109,14 +103,7 @@ def _validate_frequency_or_trigger(self) -> Self: @model_validator(mode="after") def _validate_participants(self) -> Self: """Ensure participant entries are non-blank and unique.""" - for p in self.participants: - if not p.strip(): - msg = "Empty or whitespace-only entry in participants" - raise ValueError(msg) - if len(self.participants) != len(set(self.participants)): - dupes = sorted(p for p, c in Counter(self.participants).items() if c > 1) - msg = f"Duplicate entries in participants: {dupes}" - raise ValueError(msg) + validate_non_blank_unique_strings(self.participants, "participants") return self diff --git a/src/ai_company/core/__init__.py b/src/ai_company/core/__init__.py index ab2815fece..18bd4c56f8 100644 --- a/src/ai_company/core/__init__.py +++ b/src/ai_company/core/__init__.py @@ -50,7 +50,7 @@ ) from ai_company.core.task import AcceptanceCriterion, Task from ai_company.core.task_transitions import VALID_TRANSITIONS, validate_transition -from ai_company.core.types import NotBlankStr +from ai_company.core.types import NotBlankStr, validate_non_blank_unique_strings __all__ = [ "BUILTIN_ROLES", @@ -96,5 +96,6 @@ "ToolPermissions", "get_builtin_role", "get_seniority_info", + "validate_non_blank_unique_strings", "validate_transition", ] diff --git a/src/ai_company/core/types.py b/src/ai_company/core/types.py index 7abff16c00..8b2e095aff 100644 --- a/src/ai_company/core/types.py +++ b/src/ai_company/core/types.py @@ -1,5 +1,6 @@ -"""Reusable Pydantic type annotations.""" +"""Reusable Pydantic type annotations and validators.""" +from collections import Counter from typing import Annotated from pydantic import AfterValidator, StringConstraints @@ -19,3 +20,23 @@ def _check_not_whitespace(value: str) -> str: AfterValidator(_check_not_whitespace), ] """A string that must be non-empty and not consist solely of whitespace.""" + + +def validate_non_blank_unique_strings( + values: tuple[str, ...], + field_name: str, +) -> None: + """Validate that every string in *values* is non-blank and unique. + + Raises: + ValueError: If any entry is empty/whitespace-only or if duplicates + are present. + """ + for entry in values: + if not entry.strip(): + msg = f"Empty or whitespace-only entry in {field_name}" + raise ValueError(msg) + if len(values) != len(set(values)): + dupes = sorted(v for v, c in Counter(values).items() if c > 1) + msg = f"Duplicate entries in {field_name}: {dupes}" + raise ValueError(msg)