diff --git a/CLAUDE.md b/CLAUDE.md index d33a782597..d35cdcd16c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -90,7 +90,7 @@ See `web/CLAUDE.md` for the full component inventory, design token rules, and po - **Every module** with business logic MUST have: `from synthorg.observability import get_logger` then `logger = get_logger(__name__)` - **Never** use `import logging` / `logging.getLogger()` / `print()` in application code (exception: `observability/setup.py`, `observability/sinks.py`, `observability/syslog_handler.py`, `observability/http_handler.py`, and `observability/otlp_handler.py` may use stdlib `logging` and `print(..., file=sys.stderr)` for handler construction, bootstrap, and error reporting code that runs before or during logging system configuration) - **Variable name**: always `logger` (not `_logger`, not `log`) -- **Event names**: always use constants from the domain-specific module under `synthorg.observability.events` (e.g., `API_REQUEST_STARTED` from `events.api`, `TOOL_INVOKE_START` from `events.tool`, `GIT_COMMAND_START` from `events.git`, `CONTEXT_BUDGET_FILL_UPDATED`, `CONTEXT_BUDGET_COMPACTION_STARTED`, `CONTEXT_BUDGET_COMPACTION_COMPLETED`, `CONTEXT_BUDGET_COMPACTION_FAILED`, `CONTEXT_BUDGET_COMPACTION_SKIPPED`, `CONTEXT_BUDGET_COMPACTION_FALLBACK`, `CONTEXT_BUDGET_INDICATOR_INJECTED`, `CONTEXT_BUDGET_AGENT_COMPACTION_REQUESTED`, `CONTEXT_BUDGET_EPISTEMIC_MARKERS_PRESERVED` from `events.context_budget`, `BACKUP_STARTED` from `events.backup`, `SETUP_COMPLETED` from `events.setup`, `ROUTING_CANDIDATE_SELECTED` from `events.routing`, `SHIPPING_HTTP_BATCH_SENT` from `events.shipping`, `EVAL_REPORT_COMPUTED` from `events.evaluation`, `PROMPT_PROFILE_SELECTED` from `events.prompt`, `PROCEDURAL_MEMORY_START` from `events.procedural_memory`, `PERF_LLM_JUDGE_STARTED` from `events.performance`, `TASK_ENGINE_OBSERVER_FAILED` from `events.task_engine`, `TASK_ASSIGNMENT_PROJECT_FILTERED` and `TASK_ASSIGNMENT_PROJECT_NO_ELIGIBLE` from 
`events.task_assignment`, `EXECUTION_SHUTDOWN_IMMEDIATE_CANCEL`, `EXECUTION_SHUTDOWN_TOOL_WAIT`, `EXECUTION_SHUTDOWN_CHECKPOINT_SAVE`, `EXECUTION_SHUTDOWN_CHECKPOINT_FAILED`, and `EXECUTION_PROJECT_VALIDATION_FAILED` from `events.execution`, `WORKFLOW_EXEC_COMPLETED` from `events.workflow_execution`, `BLUEPRINT_INSTANTIATE_START` from `events.blueprint`, `WORKFLOW_DEF_ROLLED_BACK` from `events.workflow_definition`, `WORKFLOW_VERSION_SAVED` from `events.workflow_version`, `MEMORY_FINE_TUNE_STARTED`, `MEMORY_SELF_EDIT_TOOL_EXECUTE`, `MEMORY_SELF_EDIT_CORE_READ`, `MEMORY_SELF_EDIT_CORE_WRITE`, `MEMORY_SELF_EDIT_CORE_WRITE_REJECTED`, `MEMORY_SELF_EDIT_ARCHIVAL_SEARCH`, `MEMORY_SELF_EDIT_ARCHIVAL_WRITE`, `MEMORY_SELF_EDIT_RECALL_READ`, `MEMORY_SELF_EDIT_RECALL_WRITE`, `MEMORY_SELF_EDIT_WRITE_FAILED` from `events.memory`, `REPORTING_GENERATION_STARTED` from `events.reporting`, `RISK_BUDGET_SCORE_COMPUTED` from `events.risk_budget`, `BUDGET_PROJECT_COST_QUERIED`, `BUDGET_PROJECT_RECORDS_QUERIED`, `BUDGET_PROJECT_BUDGET_EXCEEDED`, `BUDGET_PROJECT_ENFORCEMENT_CHECK`, `BUDGET_PROJECT_COST_AGGREGATED`, `BUDGET_PROJECT_COST_AGGREGATION_FAILED`, and `BUDGET_PROJECT_BASELINE_SOURCE` from `events.budget`, `LLM_STRATEGY_SYNTHESIZED` and `DISTILLATION_CAPTURED` from `events.consolidation`, `MEMORY_DIVERSITY_RERANKED`, `MEMORY_DIVERSITY_RERANK_FAILED`, and `MEMORY_REFORMULATION_ROUND` from `events.memory`, `NOTIFICATION_DISPATCHED` and `NOTIFICATION_DISPATCH_FAILED` from `events.notification`, `QUALITY_STEP_CLASSIFIED` from `events.quality`, `HEALTH_TICKET_EMITTED` from `events.health`, `TRAJECTORY_SCORING_START` from `events.trajectory`, `COORD_METRICS_AMDAHL_COMPUTED` from `events.coordination_metrics`, `COORDINATION_STARTED`, `COORDINATION_COMPLETED`, `COORDINATION_FAILED`, `COORDINATION_PHASE_STARTED`, `COORDINATION_PHASE_COMPLETED`, `COORDINATION_PHASE_FAILED`, `COORDINATION_WAVE_STARTED`, `COORDINATION_WAVE_COMPLETED`, `COORDINATION_TOPOLOGY_RESOLVED`, 
`COORDINATION_CLEANUP_STARTED`, `COORDINATION_CLEANUP_COMPLETED`, `COORDINATION_CLEANUP_FAILED`, `COORDINATION_WAVE_BUILT`, `COORDINATION_FACTORY_BUILT`, and `COORDINATION_ATTRIBUTION_BUILT` from `events.coordination`, `WEB_REQUEST_START` and `WEB_SSRF_BLOCKED` from `events.web`, `DB_QUERY_START` and `DB_WRITE_BLOCKED` from `events.database`, `TERMINAL_COMMAND_START` and `TERMINAL_COMMAND_BLOCKED` from `events.terminal`, `SUB_CONSTRAINT_RESOLVED` and `SUB_CONSTRAINT_DENIED` from `events.sub_constraint`, `VERSION_SAVED`, `VERSION_SNAPSHOT_FAILED`, `VERSION_LISTED`, and `VERSION_NOT_FOUND` from `events.versioning`, `ANALYTICS_AGGREGATION_COMPUTED` and `ANALYTICS_RETRY_RATE_ALERT` from `events.analytics`, `CALL_CLASSIFICATION_COMPUTED` from `events.call_classification`, `QUOTA_THRESHOLD_ALERT` and `QUOTA_POLL_FAILED` from `events.quota`, `CONFLICT_DEBATE_EVALUATOR_FAILED` from `events.conflict`, `DELEGATION_LOOP_CIRCUIT_BACKOFF` and `DELEGATION_LOOP_CIRCUIT_PERSIST_FAILED` from `events.delegation`, `MEETING_EVENT_COOLDOWN_SKIPPED` and `MEETING_TASKS_CAPPED` from `events.meeting`, `PERSISTENCE_CIRCUIT_BREAKER_SAVED`, `PERSISTENCE_CIRCUIT_BREAKER_SAVE_FAILED`, `PERSISTENCE_CIRCUIT_BREAKER_LOADED`, `PERSISTENCE_CIRCUIT_BREAKER_LOAD_FAILED`, `PERSISTENCE_CIRCUIT_BREAKER_DELETED`, `PERSISTENCE_CIRCUIT_BREAKER_DELETE_FAILED`, `PERSISTENCE_PROJECT_COST_AGG_INCREMENTED`, `PERSISTENCE_PROJECT_COST_AGG_INCREMENT_FAILED`, `PERSISTENCE_PROJECT_COST_AGG_FETCHED`, `PERSISTENCE_PROJECT_COST_AGG_FETCH_FAILED`, and `PERSISTENCE_PROJECT_COST_AGG_DESERIALIZE_FAILED` from `events.persistence`, `METRICS_SCRAPE_COMPLETED`, `METRICS_SCRAPE_FAILED`, `METRICS_COLLECTOR_INITIALIZED`, `METRICS_COORDINATION_RECORDED`, `METRICS_OTLP_EXPORT_COMPLETED` and `METRICS_OTLP_FLUSHER_STOPPED` from `events.metrics`, `ORG_MEMORY_QUERY_START`, `ORG_MEMORY_QUERY_COMPLETE`, `ORG_MEMORY_QUERY_FAILED`, `ORG_MEMORY_WRITE_START`, `ORG_MEMORY_WRITE_COMPLETE`, `ORG_MEMORY_WRITE_DENIED`, `ORG_MEMORY_WRITE_FAILED`, 
`ORG_MEMORY_POLICIES_LISTED`, `ORG_MEMORY_BACKEND_CREATED`, `ORG_MEMORY_CONNECT_FAILED`, `ORG_MEMORY_DISCONNECT_FAILED`, `ORG_MEMORY_NOT_CONNECTED`, `ORG_MEMORY_ROW_PARSE_FAILED`, `ORG_MEMORY_CONFIG_INVALID`, `ORG_MEMORY_MODEL_INVALID`, `ORG_MEMORY_MVCC_PUBLISH_APPENDED`, `ORG_MEMORY_MVCC_RETRACT_APPENDED`, `ORG_MEMORY_MVCC_SNAPSHOT_AT_QUERIED`, and `ORG_MEMORY_MVCC_LOG_QUERIED` from `events.org_memory`). Each domain has its own module -- see `src/synthorg/observability/events/` for the full inventory of constants. Import directly: `from synthorg.observability.events.<domain> import EVENT_CONSTANT` +- **Event names**: always use constants from the domain-specific module under `synthorg.observability.events` (e.g., `API_REQUEST_STARTED` from `events.api`, `TOOL_INVOKE_START` from `events.tool`). Each domain has its own module -- see `src/synthorg/observability/events/` for the full inventory of constants. Import directly: `from synthorg.observability.events.<domain> import EVENT_CONSTANT` - **Structured kwargs**: always `logger.info(EVENT, key=value)` -- never `logger.info("msg %s", val)` - **All error paths** must log at WARNING or ERROR with context before raising - **All state transitions** must log at INFO diff --git a/docs/design/operations.md b/docs/design/operations.md index ee1b520c1d..94a6a62f7c 100644 --- a/docs/design/operations.md +++ b/docs/design/operations.md @@ -691,8 +691,14 @@ Per-category backend selection is implemented in `tools/sandbox/factory.py` via `build_sandbox_backends` (instantiates only the backends referenced by config), `resolve_sandbox_for_category` (looks up the correct backend for a `ToolCategory`), and `cleanup_sandbox_backends` (parallel cleanup with error isolation). The tool factory -(`build_default_tools_from_config`) wires `VERSION_CONTROL` category; other categories will -be wired as their tool builders are added. +(`build_default_tools_from_config`) wires tool categories. Core tools +(`FILE_SYSTEM`, `VERSION_CONTROL`, web, etc.) 
are part of the default toolset +and always registered. The +auxiliary categories `DESIGN`, `COMMUNICATION`, and `ANALYTICS` are opt-in: tools +are only registered when the corresponding config section is present, and some +individual tools additionally require a runtime dependency (e.g. image tools +require an ``ImageProvider``, notification tools require a dispatcher, analytics +query/metric tools require a provider or sink). Docker is optional -- only required when code execution, terminal, web, or database tools are enabled. File system and git tools work out of the box with subprocess isolation. This keeps diff --git a/docs/reference/claude-reference.md b/docs/reference/claude-reference.md index d72032f13b..5c0556ea0a 100644 --- a/docs/reference/claude-reference.md +++ b/docs/reference/claude-reference.md @@ -55,7 +55,7 @@ src/synthorg/ settings/ # Runtime-editable settings (DB > env > YAML > code), Fernet encryption, ConfigResolver, definitions/, subscribers/ (SecuritySubscriber for discovery allowlist hot-reload) security/ # Rule engine, audit log, output scanner, progressive trust, autonomy levels, timeout policies, LLM fallback evaluator, custom policy rules, risk scoring (pluggable RiskScorer protocol, multi-dimensional RiskScore, DefaultRiskScorer), enforcement modes (active/shadow/disabled via SecurityEnforcementMode), risk override (SecOps risk tier reclassification via RiskTierOverride + SecOpsRiskClassifier), SSRF violation tracking (SsrfViolation model, pending/allowed/denied status for self-healing discovery allowlist) templates/ # Pre-built company templates (inheritance tree), template merge engine, personality presets, preset discovery/CRUD service, model requirements, tier-to-model matching, locale-aware name generation, workflow config rendering, pack_loader (additive team packs), packs/ (built-in pack YAMLs), uses_packs composition - tools/ # Tool registry, built-in tools, git SSRF prevention, MCP bridge, sandbox factory (gVisor default 
overrides via merge_gvisor_defaults), invocation tracking, network_validator (shared SSRF), sub_constraints (per-level constraint models), sub_constraint_enforcer (granular enforcement), web/ (HTTP requests, HTML parsing, web search), database/ (SQL query, schema inspection), terminal/ (sandboxed shell commands), sandbox/ (4-domain SandboxPolicy model (filesystem/network/process/inference), SandboxRuntimeResolver (gVisor probe + per-category runtime resolution with fallback), SandboxCredentialManager (env var credential stripping), SandboxAuthProxy (LLM traffic auth proxy stub)) + tools/ # Tool registry, built-in tools, git SSRF prevention, MCP bridge, sandbox factory (gVisor default overrides via merge_gvisor_defaults), invocation tracking, network_validator (shared SSRF), sub_constraints (per-level constraint models), sub_constraint_enforcer (granular enforcement), web/ (HTTP requests, HTML parsing, web search), database/ (SQL query, schema inspection), terminal/ (sandboxed shell commands), design/ (image generation via ImageProvider protocol, diagram DSL generation, asset management), communication/ (SMTP email sending, notification dispatch via NotificationDispatcherProtocol, Jinja2 template formatting), analytics/ (data aggregation via AnalyticsProvider protocol, report generation, metric collection via MetricSink protocol), sandbox/ (4-domain SandboxPolicy model (filesystem/network/process/inference), SandboxRuntimeResolver (gVisor probe + per-category runtime resolution with fallback), SandboxCredentialManager (env var credential stripping), SandboxAuthProxy (LLM traffic auth proxy stub)) web/src/ # React 19 dashboard (see web/CLAUDE.md for full structure) cli/ # Go CLI binary (see cli/CLAUDE.md for full structure) diff --git a/src/synthorg/config/defaults.py b/src/synthorg/config/defaults.py index add15e6912..37e4075c83 100644 --- a/src/synthorg/config/defaults.py +++ b/src/synthorg/config/defaults.py @@ -48,4 +48,7 @@ def default_config_dict() -> dict[str, 
object]: "web": None, "database": None, "terminal": None, + "design_tools": None, + "communication_tools": None, + "analytics_tools": None, } diff --git a/src/synthorg/config/schema.py b/src/synthorg/config/schema.py index bef9eae91a..bb25f55f95 100644 --- a/src/synthorg/config/schema.py +++ b/src/synthorg/config/schema.py @@ -37,7 +37,10 @@ from synthorg.providers.enums import AuthType from synthorg.security.config import SecurityConfig from synthorg.security.trust.config import TrustConfig +from synthorg.tools.analytics.config import AnalyticsToolsConfig # noqa: TC001 +from synthorg.tools.communication.config import CommunicationToolsConfig # noqa: TC001 from synthorg.tools.database.config import DatabaseConfig # noqa: TC001 +from synthorg.tools.design.config import DesignToolsConfig # noqa: TC001 from synthorg.tools.git_url_validator import GitCloneNetworkPolicy from synthorg.tools.mcp.config import MCPConfig from synthorg.tools.sandbox.sandboxing_config import SandboxingConfig @@ -592,6 +595,11 @@ class RootConfig(BaseModel): tools). terminal: Terminal tool configuration (``None`` = default terminal config). + design_tools: Design tool configuration (``None`` = disabled). + communication_tools: Communication tool configuration + (``None`` = disabled). + analytics_tools: Analytics tool configuration + (``None`` = disabled). 
""" model_config = ConfigDict(frozen=True, allow_inf_nan=False) @@ -739,6 +747,18 @@ class RootConfig(BaseModel): default=None, description="Terminal tool configuration (None = default terminal config)", ) + design_tools: DesignToolsConfig | None = Field( + default=None, + description="Design tool configuration (None = disabled)", + ) + communication_tools: CommunicationToolsConfig | None = Field( + default=None, + description="Communication tool configuration (None = disabled)", + ) + analytics_tools: AnalyticsToolsConfig | None = Field( + default=None, + description="Analytics tool configuration (None = disabled)", + ) @model_validator(mode="after") def _validate_unique_agent_names(self) -> Self: diff --git a/src/synthorg/observability/events/analytics.py b/src/synthorg/observability/events/analytics.py index ceabb015eb..8d6059de7f 100644 --- a/src/synthorg/observability/events/analytics.py +++ b/src/synthorg/observability/events/analytics.py @@ -12,3 +12,23 @@ ANALYTICS_RETRY_RATE_ALERT: Final[str] = "analytics.retry_rate_alert" ANALYTICS_ORCHESTRATION_ALERT: Final[str] = "analytics.orchestration_alert" ANALYTICS_SERVICE_CREATED: Final[str] = "analytics.service_created" + +# Tool: data aggregation queries +ANALYTICS_TOOL_QUERY_START: Final[str] = "analytics.tool.query_start" +ANALYTICS_TOOL_QUERY_SUCCESS: Final[str] = "analytics.tool.query_success" +ANALYTICS_TOOL_QUERY_FAILED: Final[str] = "analytics.tool.query_failed" + +# Tool: report generation +ANALYTICS_TOOL_REPORT_START: Final[str] = "analytics.tool.report_start" +ANALYTICS_TOOL_REPORT_SUCCESS: Final[str] = "analytics.tool.report_success" +ANALYTICS_TOOL_REPORT_FAILED: Final[str] = "analytics.tool.report_failed" + +# Tool: metric collection +ANALYTICS_TOOL_METRIC_RECORDED: Final[str] = "analytics.tool.metric_recorded" +ANALYTICS_TOOL_METRIC_RECORD_FAILED: Final[str] = "analytics.tool.metric_record_failed" +ANALYTICS_TOOL_METRIC_NOT_ALLOWED: Final[str] = "analytics.tool.metric_not_allowed" + +# Tool: 
provider +ANALYTICS_TOOL_PROVIDER_NOT_CONFIGURED: Final[str] = ( + "analytics.tool.provider_not_configured" +) diff --git a/src/synthorg/observability/events/communication.py b/src/synthorg/observability/events/communication.py index f65916df11..277d6f2717 100644 --- a/src/synthorg/observability/events/communication.py +++ b/src/synthorg/observability/events/communication.py @@ -58,3 +58,34 @@ # Shutdown COMM_BUS_SHUTDOWN_SIGNAL: Final[str] = "communication.bus.shutdown_signal" + +# Tool: email sending +COMM_TOOL_EMAIL_SEND_START: Final[str] = "communication.tool.email.send_start" +COMM_TOOL_EMAIL_SEND_SUCCESS: Final[str] = "communication.tool.email.send_success" +COMM_TOOL_EMAIL_SEND_FAILED: Final[str] = "communication.tool.email.send_failed" +COMM_TOOL_EMAIL_VALIDATION_FAILED: Final[str] = ( + "communication.tool.email.validation_failed" +) + +# Tool: notification sending +COMM_TOOL_NOTIFICATION_SEND_START: Final[str] = ( + "communication.tool.notification.send_start" +) +COMM_TOOL_NOTIFICATION_SEND_SUCCESS: Final[str] = ( + "communication.tool.notification.send_success" +) +COMM_TOOL_NOTIFICATION_SEND_FAILED: Final[str] = ( + "communication.tool.notification.send_failed" +) + +# Tool: template rendering +COMM_TOOL_TEMPLATE_RENDER_START: Final[str] = "communication.tool.template.render_start" +COMM_TOOL_TEMPLATE_RENDER_SUCCESS: Final[str] = ( + "communication.tool.template.render_success" +) +COMM_TOOL_TEMPLATE_RENDER_FAILED: Final[str] = ( + "communication.tool.template.render_failed" +) +COMM_TOOL_TEMPLATE_RENDER_INVALID: Final[str] = ( + "communication.tool.template.render_invalid" +) diff --git a/src/synthorg/observability/events/design.py b/src/synthorg/observability/events/design.py new file mode 100644 index 0000000000..19bb5c047b --- /dev/null +++ b/src/synthorg/observability/events/design.py @@ -0,0 +1,25 @@ +"""Design tool event constants.""" + +from typing import Final + +# Image generation +DESIGN_IMAGE_GENERATION_START: Final[str] = 
"design.image.generation_start" +DESIGN_IMAGE_GENERATION_SUCCESS: Final[str] = "design.image.generation_success" +DESIGN_IMAGE_GENERATION_FAILED: Final[str] = "design.image.generation_failed" +DESIGN_IMAGE_GENERATION_TIMEOUT: Final[str] = "design.image.generation_timeout" + +# Diagram generation +DESIGN_DIAGRAM_GENERATION_START: Final[str] = "design.diagram.generation_start" +DESIGN_DIAGRAM_GENERATION_SUCCESS: Final[str] = "design.diagram.generation_success" +DESIGN_DIAGRAM_GENERATION_FAILED: Final[str] = "design.diagram.generation_failed" + +# Asset management +DESIGN_ASSET_STORED: Final[str] = "design.asset.stored" +DESIGN_ASSET_RETRIEVED: Final[str] = "design.asset.retrieved" +DESIGN_ASSET_DELETED: Final[str] = "design.asset.deleted" +DESIGN_ASSET_LISTED: Final[str] = "design.asset.listed" +DESIGN_ASSET_SEARCHED: Final[str] = "design.asset.searched" +DESIGN_ASSET_VALIDATION_FAILED: Final[str] = "design.asset.validation_failed" + +# Provider +DESIGN_PROVIDER_NOT_CONFIGURED: Final[str] = "design.provider.not_configured" diff --git a/src/synthorg/tools/analytics/__init__.py b/src/synthorg/tools/analytics/__init__.py new file mode 100644 index 0000000000..d8626dbc79 --- /dev/null +++ b/src/synthorg/tools/analytics/__init__.py @@ -0,0 +1,23 @@ +"""Built-in analytics tools for data aggregation, reporting, and metrics.""" + +from synthorg.tools.analytics.base_analytics_tool import BaseAnalyticsTool +from synthorg.tools.analytics.config import AnalyticsToolsConfig +from synthorg.tools.analytics.data_aggregator import ( + AnalyticsProvider, + DataAggregatorTool, +) +from synthorg.tools.analytics.metric_collector import ( + MetricCollectorTool, + MetricSink, +) +from synthorg.tools.analytics.report_generator import ReportGeneratorTool + +__all__ = [ + "AnalyticsProvider", + "AnalyticsToolsConfig", + "BaseAnalyticsTool", + "DataAggregatorTool", + "MetricCollectorTool", + "MetricSink", + "ReportGeneratorTool", +] diff --git 
a/src/synthorg/tools/analytics/base_analytics_tool.py b/src/synthorg/tools/analytics/base_analytics_tool.py new file mode 100644 index 0000000000..16f2d2b06f --- /dev/null +++ b/src/synthorg/tools/analytics/base_analytics_tool.py @@ -0,0 +1,67 @@ +"""Base class for analytics tools. + +Provides the common ``ToolCategory.ANALYTICS`` category, a +shared configuration reference, and a metric-name validation +helper. +""" + +from abc import ABC +from typing import Any + +from synthorg.core.enums import ToolCategory +from synthorg.tools.analytics.config import AnalyticsToolsConfig +from synthorg.tools.base import BaseTool + + +class BaseAnalyticsTool(BaseTool, ABC): + """Abstract base for all analytics tools. + + Sets ``category=ToolCategory.ANALYTICS`` and holds a shared + ``AnalyticsToolsConfig``. + """ + + def __init__( + self, + *, + name: str, + description: str = "", + parameters_schema: dict[str, Any] | None = None, + action_type: str | None = None, + config: AnalyticsToolsConfig | None = None, + ) -> None: + """Initialize an analytics tool with configuration. + + Args: + name: Tool name. + description: Human-readable description. + parameters_schema: JSON Schema for tool parameters. + action_type: Security action type override. + config: Analytics tool configuration. + """ + super().__init__( + name=name, + description=description, + category=ToolCategory.ANALYTICS, + parameters_schema=parameters_schema, + action_type=action_type, + ) + self._config = config or AnalyticsToolsConfig() + + @property + def config(self) -> AnalyticsToolsConfig: + """The analytics tool configuration.""" + return self._config + + def _is_metric_allowed(self, metric_name: str) -> bool: + """Check if a metric name is allowed by the whitelist. + + Args: + metric_name: Name of the metric to check. + + Returns: + ``True`` if the metric is allowed (or no whitelist + is configured). 
+ """ + if self._config.allowed_metrics is None: + return True + return metric_name in self._config.allowed_metrics diff --git a/src/synthorg/tools/analytics/config.py b/src/synthorg/tools/analytics/config.py new file mode 100644 index 0000000000..c900cf9acf --- /dev/null +++ b/src/synthorg/tools/analytics/config.py @@ -0,0 +1,35 @@ +"""Configuration models for analytics tools.""" + +from pydantic import BaseModel, ConfigDict, Field + +from synthorg.core.types import NotBlankStr # noqa: TC001 + + +class AnalyticsToolsConfig(BaseModel): + """Top-level configuration for analytics tools. + + Attributes: + query_timeout: Maximum query execution time in seconds. + max_rows: Maximum rows returned from aggregation queries. + allowed_metrics: Optional whitelist of metric names agents + can query. ``None`` means all metrics are accessible. + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + query_timeout: float = Field( + default=60.0, + gt=0, + le=300.0, + description="Query timeout (seconds)", + ) + max_rows: int = Field( + default=10_000, + gt=0, + le=100_000, + description="Maximum rows in aggregation results", + ) + allowed_metrics: frozenset[NotBlankStr] | None = Field( + default=None, + description="Metric whitelist (None = all allowed)", + ) diff --git a/src/synthorg/tools/analytics/data_aggregator.py b/src/synthorg/tools/analytics/data_aggregator.py new file mode 100644 index 0000000000..d84309476b --- /dev/null +++ b/src/synthorg/tools/analytics/data_aggregator.py @@ -0,0 +1,366 @@ +"""Data aggregator tool -- query and aggregate analytics data via provider. + +The ``AnalyticsProvider`` protocol defines a vendor-agnostic interface +for querying analytics backends. No concrete implementation is +shipped -- users inject a provider at construction time. 
+""" + +import asyncio +import copy +from datetime import datetime +from typing import Any, Final, Protocol, runtime_checkable + +from synthorg.core.enums import ActionType +from synthorg.observability import get_logger +from synthorg.observability.events.analytics import ( + ANALYTICS_TOOL_PROVIDER_NOT_CONFIGURED, + ANALYTICS_TOOL_QUERY_FAILED, + ANALYTICS_TOOL_QUERY_START, + ANALYTICS_TOOL_QUERY_SUCCESS, +) +from synthorg.tools.analytics.base_analytics_tool import BaseAnalyticsTool +from synthorg.tools.analytics.config import AnalyticsToolsConfig # noqa: TC001 +from synthorg.tools.base import ToolExecutionResult + +logger = get_logger(__name__) + +_VALID_PERIODS: Final[frozenset[str]] = frozenset({"7d", "30d", "90d", "custom"}) + +_VALID_GROUP_BY: Final[frozenset[str]] = frozenset( + {"day", "week", "month", "agent", "department"} +) + + +@runtime_checkable +class AnalyticsProvider(Protocol): + """Abstracted analytics data provider protocol. + + Implementations must be async and return query results + as a dictionary. + """ + + async def query( + self, + *, + metrics: list[str], + period: str, + group_by: str | None = None, + start_date: str | None = None, + end_date: str | None = None, + ) -> dict[str, Any]: + """Query analytics data. + + Args: + metrics: Metric names to aggregate. + period: Time period (7d, 30d, 90d, or custom). + group_by: Optional grouping dimension. + start_date: Start date for custom period (ISO 8601). + end_date: End date for custom period (ISO 8601). + + Returns: + Query results as a dictionary. + """ + ... + + +_PARAMETERS_SCHEMA: Final[dict[str, Any]] = { + "type": "object", + "properties": { + "metrics": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "Metric names to aggregate (e.g. 
'total_cost', 'task_completion_rate')" + ), + }, + "period": { + "type": "string", + "enum": sorted(_VALID_PERIODS), + "description": "Time period for aggregation", + }, + "group_by": { + "type": "string", + "enum": sorted(_VALID_GROUP_BY), + "description": "Optional grouping dimension", + }, + "start_date": { + "type": "string", + "description": "Start date for custom period (ISO 8601)", + }, + "end_date": { + "type": "string", + "description": "End date for custom period (ISO 8601)", + }, + }, + "required": ["metrics", "period"], + "additionalProperties": False, +} + + +class DataAggregatorTool(BaseAnalyticsTool): + """Query and aggregate analytics data via a provider. + + Requires an ``AnalyticsProvider`` to be injected at construction + time. Validates metric names against the optional whitelist + in ``AnalyticsToolsConfig``. + + Examples: + Query metrics:: + + tool = DataAggregatorTool(provider=my_provider) + result = await tool.execute( + arguments={ + "metrics": ["total_cost", "task_count"], + "period": "7d", + "group_by": "day", + } + ) + """ + + def __init__( + self, + *, + provider: AnalyticsProvider | None = None, + config: AnalyticsToolsConfig | None = None, + ) -> None: + """Initialize the data aggregator tool. + + Args: + provider: Analytics data provider. ``None`` means + the tool will return an error on execution. + config: Analytics tool configuration. + """ + super().__init__( + name="data_aggregator", + description=( + "Query and aggregate analytics data " + "(costs, tasks, performance metrics)." + ), + parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA), + action_type=ActionType.CODE_READ, + config=config, + ) + self._provider = provider + + def _validate_query_params( + self, + metrics: list[str], + period: str, + group_by: str | None, + start_date: str | None, + end_date: str | None, + ) -> ToolExecutionResult | None: + """Validate query parameters. 
+ + Returns a ``ToolExecutionResult`` error if validation fails, + or ``None`` if all parameters are valid. + """ + blocked = [m for m in metrics if not self._is_metric_allowed(m)] + if blocked: + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="metrics_not_allowed", + blocked=blocked, + ) + return ToolExecutionResult( + content=( + f"Metrics not allowed: {blocked}. " + f"Allowed: {sorted(self._config.allowed_metrics or set())}" + ), + is_error=True, + ) + + if period not in _VALID_PERIODS: + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="invalid_period", + period=period, + ) + return ToolExecutionResult( + content=( + f"Invalid period: {period!r}. " + f"Must be one of: {sorted(_VALID_PERIODS)}" + ), + is_error=True, + ) + + if period == "custom" and (not start_date or not end_date): + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="missing_custom_dates", + ) + return ToolExecutionResult( + content="Custom period requires both start_date and end_date.", + is_error=True, + ) + + for date_label, date_val in ( + ("start_date", start_date), + ("end_date", end_date), + ): + if date_val is not None: + try: + datetime.fromisoformat(date_val) + except ValueError: + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="invalid_date", + field=date_label, + value=date_val, + ) + return ToolExecutionResult( + content=( + f"Invalid {date_label}: {date_val!r}. " + f"Must be ISO 8601 format." + ), + is_error=True, + ) + + if group_by is not None and group_by not in _VALID_GROUP_BY: + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="invalid_group_by", + group_by=group_by, + ) + return ToolExecutionResult( + content=( + f"Invalid group_by: {group_by!r}. " + f"Must be one of: {sorted(_VALID_GROUP_BY)}" + ), + is_error=True, + ) + + return None + + async def execute( # noqa: PLR0911 + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + """Query analytics data. 
+ + Args: + arguments: Must contain ``metrics`` and ``period``; + optionally ``group_by``, ``start_date``, ``end_date``. + + Returns: + A ``ToolExecutionResult`` with aggregated data. + """ + if self._provider is None: + logger.warning( + ANALYTICS_TOOL_PROVIDER_NOT_CONFIGURED, + tool="data_aggregator", + ) + return ToolExecutionResult( + content=( + "Analytics queries require a configured provider. " + "No AnalyticsProvider has been injected." + ), + is_error=True, + ) + + metrics = arguments.get("metrics") + period = arguments.get("period") + if not isinstance(metrics, list) or not metrics: + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="missing_or_invalid_metrics", + ) + return ToolExecutionResult( + content="'metrics' must be a non-empty list of strings.", + is_error=True, + ) + if not isinstance(period, str) or not period: + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="missing_or_invalid_period", + ) + return ToolExecutionResult( + content="'period' must be a non-empty string.", + is_error=True, + ) + group_by: str | None = arguments.get("group_by") + start_date: str | None = arguments.get("start_date") + end_date: str | None = arguments.get("end_date") + + error = self._validate_query_params( + metrics, + period, + group_by, + start_date, + end_date, + ) + if error is not None: + return error + + logger.info( + ANALYTICS_TOOL_QUERY_START, + metrics=metrics, + period=period, + group_by=group_by, + ) + + try: + data = await asyncio.wait_for( + self._provider.query( + metrics=metrics, + period=period, + group_by=group_by, + start_date=start_date, + end_date=end_date, + ), + timeout=self._config.query_timeout, + ) + except TimeoutError: + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="query_timeout", + timeout=self._config.query_timeout, + metrics=metrics, + ) + return ToolExecutionResult( + content=( + f"Analytics query timed out after {self._config.query_timeout}s" + ), + is_error=True, + ) + except MemoryError, RecursionError: + 
raise + except Exception: + logger.warning( + ANALYTICS_TOOL_QUERY_FAILED, + error="provider_error", + metrics=metrics, + period=period, + exc_info=True, + ) + return ToolExecutionResult( + content="Analytics query failed.", + is_error=True, + ) + + # Enforce max_rows without mutating the provider's result. + sanitized = { + k: ( + v[: self._config.max_rows] + if isinstance(v, list) and len(v) > self._config.max_rows + else v + ) + for k, v in data.items() + } + + logger.info( + ANALYTICS_TOOL_QUERY_SUCCESS, + metrics=metrics, + result_keys=sorted(sanitized.keys()), + ) + + # Format results as readable text + lines = [f"Analytics query results ({period}):"] + for key, value in sorted(sanitized.items()): + lines.append(f" {key}: {value}") + + return ToolExecutionResult( + content="\n".join(lines), + metadata=sanitized, + ) diff --git a/src/synthorg/tools/analytics/metric_collector.py b/src/synthorg/tools/analytics/metric_collector.py new file mode 100644 index 0000000000..418106e6b5 --- /dev/null +++ b/src/synthorg/tools/analytics/metric_collector.py @@ -0,0 +1,247 @@ +"""Metric collector tool -- record custom metrics via a sink. + +The ``MetricSink`` protocol defines a vendor-agnostic interface +for recording metrics. No concrete implementation is shipped -- +users inject a sink at construction time. +""" + +import copy +import math +from typing import Any, Final, Protocol, runtime_checkable + +from synthorg.observability import get_logger +from synthorg.observability.events.analytics import ( + ANALYTICS_TOOL_METRIC_NOT_ALLOWED, + ANALYTICS_TOOL_METRIC_RECORD_FAILED, + ANALYTICS_TOOL_METRIC_RECORDED, +) +from synthorg.tools.analytics.base_analytics_tool import BaseAnalyticsTool +from synthorg.tools.analytics.config import AnalyticsToolsConfig # noqa: TC001 +from synthorg.tools.base import ToolExecutionResult + +logger = get_logger(__name__) + + +@runtime_checkable +class MetricSink(Protocol): + """Abstracted metric recording sink protocol. 
class MetricCollectorTool(BaseAnalyticsTool):
    """Record custom metrics via an abstracted sink.

    Allows agents to record observations and measurements that are
    forwarded to the injected ``MetricSink``.  Arguments are validated
    before the sink is called: the name must be a non-blank string, the
    value a finite number, ``tags`` a string-to-string mapping, and
    ``unit`` a string or ``None``.

    Examples:
        Record a metric::

            tool = MetricCollectorTool(sink=my_sink)
            result = await tool.execute(
                arguments={
                    "metric_name": "response_time",
                    "value": 1.23,
                    "unit": "seconds",
                    "tags": {"endpoint": "/api/tasks"},
                }
            )
    """

    def __init__(
        self,
        *,
        sink: MetricSink | None = None,
        config: AnalyticsToolsConfig | None = None,
    ) -> None:
        """Initialize the metric collector tool.

        Args:
            sink: Metric recording sink.  ``None`` means the
                tool will return an error on execution.
            config: Analytics tool configuration.
        """
        super().__init__(
            name="metric_collector",
            description=(
                "Record custom metrics (counters, gauges, timings) "
                "to the configured metric backend."
            ),
            # Deep copy so later schema edits never touch the module constant.
            parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA),
            action_type="metrics:record",
            config=config,
        )
        self._sink = sink

    @staticmethod
    def _metric_failure(content: str, **log_fields: Any) -> ToolExecutionResult:
        """Log ``ANALYTICS_TOOL_METRIC_RECORD_FAILED`` and build the error result."""
        logger.warning(ANALYTICS_TOOL_METRIC_RECORD_FAILED, **log_fields)
        return ToolExecutionResult(content=content, is_error=True)

    @staticmethod
    def _to_float(raw_value: float) -> float:
        """Convert a validated number to ``float`` without raising.

        ``float()`` raises ``OverflowError`` for ints beyond the double
        range; those are mapped to a signed infinity so the regular
        finite-value check in ``execute`` rejects them.
        """
        try:
            return float(raw_value)
        except OverflowError:
            return math.inf if raw_value > 0 else -math.inf

    @staticmethod
    def _coerce_tags(raw_tags: object) -> dict[str, str] | None:
        """Return a private copy of *raw_tags*, or ``None`` when malformed.

        A missing/``None`` value yields an empty dict.  Anything that is
        not a mapping of ``str`` keys to ``str`` values is rejected.
        """
        if raw_tags is None:
            return {}
        if not isinstance(raw_tags, dict):
            return None
        if not all(
            isinstance(key, str) and isinstance(val, str)
            for key, val in raw_tags.items()
        ):
            return None
        # Copy so neither the sink nor the result metadata aliases caller state.
        return dict(raw_tags)

    async def execute(  # noqa: PLR0911
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Record a metric data point.

        Args:
            arguments: Must contain ``metric_name`` and ``value``;
                optionally ``tags`` and ``unit``.

        Returns:
            A ``ToolExecutionResult`` confirming the recording, or an
            error result when no sink is configured, an argument is
            invalid, the metric is not allowed, or the sink raises.

        Raises:
            MemoryError: Re-raised unchanged if the sink raises it.
            RecursionError: Re-raised unchanged if the sink raises it.
        """
        if self._sink is None:
            return self._metric_failure(
                "Metric recording requires a configured sink. "
                "No MetricSink has been injected.",
                metric_name=arguments.get("metric_name", "unknown"),
                error="sink_not_configured",
            )

        metric_name = arguments.get("metric_name")
        raw_value = arguments.get("value")
        if not isinstance(metric_name, str) or not metric_name.strip():
            return self._metric_failure(
                "'metric_name' must be a non-empty string.",
                error="missing_or_invalid_metric_name",
            )
        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(raw_value, bool) or not isinstance(raw_value, int | float):
            return self._metric_failure(
                "'value' must be a number (not bool).",
                error="invalid_value_type",
            )
        value = self._to_float(raw_value)

        tags = self._coerce_tags(arguments.get("tags"))
        if tags is None:
            return self._metric_failure(
                "'tags' must be an object mapping strings to strings.",
                metric_name=metric_name,
                error="invalid_tags",
            )

        unit = arguments.get("unit")
        if unit is not None and not isinstance(unit, str):
            return self._metric_failure(
                "'unit' must be a string or null.",
                error="invalid_unit_type",
            )

        if not self._is_metric_allowed(metric_name):
            logger.warning(
                ANALYTICS_TOOL_METRIC_NOT_ALLOWED,
                metric_name=metric_name,
            )
            return ToolExecutionResult(
                content=(
                    f"Metric not allowed: {metric_name!r}. "
                    f"Allowed: {sorted(self._config.allowed_metrics or set())}"
                ),
                is_error=True,
            )

        if not math.isfinite(value):
            return self._metric_failure(
                f"Metric value must be finite: {metric_name!r} got {value}",
                metric_name=metric_name,
                error="non_finite_value",
                value=str(value),
            )

        try:
            await self._sink.record(
                name=metric_name,
                value=value,
                tags=tags,
                unit=unit,
            )
        except (MemoryError, RecursionError):
            raise
        except Exception:
            logger.warning(
                ANALYTICS_TOOL_METRIC_RECORD_FAILED,
                metric_name=metric_name,
                error="sink_error",
                exc_info=True,
            )
            return ToolExecutionResult(
                content="Metric recording failed.",
                is_error=True,
            )

        logger.info(
            ANALYTICS_TOOL_METRIC_RECORDED,
            metric_name=metric_name,
            value=value,
            unit=unit,
        )

        unit_suffix = f" {unit}" if unit else ""
        return ToolExecutionResult(
            content=(f"Metric recorded: {metric_name} = {value}{unit_suffix}"),
            metadata={
                "metric_name": metric_name,
                "value": value,
                "tags": tags,
                "unit": unit,
            },
        )
class ReportGeneratorTool(BaseAnalyticsTool):
    """Generate formatted analytics reports.

    Queries the analytics provider for the metrics associated with a
    report type and renders the result as text, markdown, or JSON.

    Examples:
        Generate a budget report::

            tool = ReportGeneratorTool(provider=my_provider)
            result = await tool.execute(
                arguments={
                    "report_type": "budget_summary",
                    "period": "30d",
                    "format": "markdown",
                }
            )
    """

    def __init__(
        self,
        *,
        provider: AnalyticsProvider | None = None,
        config: AnalyticsToolsConfig | None = None,
    ) -> None:
        """Initialize the report generator tool.

        Args:
            provider: Analytics data provider.  ``None`` means
                the tool will return an error on execution.
            config: Analytics tool configuration.
        """
        super().__init__(
            name="report_generator",
            description=(
                "Generate formatted analytics reports "
                "(budget, performance, trends, cost breakdown)."
            ),
            # Deep copy so later schema edits never touch the module constant.
            parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA),
            action_type=ActionType.CODE_READ,
            config=config,
        )
        self._provider = provider

    @staticmethod
    def _report_failure(content: str, **log_fields: Any) -> ToolExecutionResult:
        """Log ``ANALYTICS_TOOL_REPORT_FAILED`` and build the error result."""
        logger.warning(ANALYTICS_TOOL_REPORT_FAILED, **log_fields)
        return ToolExecutionResult(content=content, is_error=True)

    async def execute(  # noqa: PLR0911, PLR0912, C901
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Generate an analytics report.

        Args:
            arguments: Must contain ``report_type`` and ``period``;
                optionally ``format`` (defaults to ``"markdown"``).

        Returns:
            A ``ToolExecutionResult`` with the formatted report, or an
            error result when no provider is configured, an argument is
            invalid, a required metric is not allowed, the query times
            out, or the provider/formatter raises.

        Raises:
            MemoryError: Re-raised unchanged from the provider or formatter.
            RecursionError: Re-raised unchanged from the provider or formatter.
        """
        if self._provider is None:
            logger.warning(
                ANALYTICS_TOOL_PROVIDER_NOT_CONFIGURED,
                tool="report_generator",
            )
            return ToolExecutionResult(
                content=(
                    "Report generation requires a configured provider. "
                    "No AnalyticsProvider has been injected."
                ),
                is_error=True,
            )

        report_type = arguments.get("report_type")
        period = arguments.get("period")
        output_format = arguments.get("format", "markdown")

        # Type checks come first: the frozenset membership tests below
        # raise TypeError for unhashable values such as lists or dicts.
        if not isinstance(report_type, str):
            return self._report_failure(
                "'report_type' must be a string.",
                error="missing_or_invalid_report_type",
            )
        if not isinstance(period, str):
            return self._report_failure(
                "'period' must be a string.",
                error="missing_or_invalid_period",
            )
        if not isinstance(output_format, str):
            return self._report_failure(
                "'format' must be a string.",
                error="missing_or_invalid_format",
            )

        if report_type not in _REPORT_TYPES:
            return self._report_failure(
                f"Invalid report_type: {report_type!r}. "
                f"Must be one of: {sorted(_REPORT_TYPES)}",
                error="invalid_report_type",
                report_type=report_type,
            )
        if period not in _VALID_PERIODS:
            return self._report_failure(
                f"Invalid period: {period!r}. "
                f"Must be one of: {sorted(_VALID_PERIODS)}",
                error="invalid_period",
                period=period,
            )
        if output_format not in _OUTPUT_FORMATS:
            return self._report_failure(
                f"Invalid format: {output_format!r}. "
                f"Must be one of: {sorted(_OUTPUT_FORMATS)}",
                error="invalid_output_format",
                output_format=output_format,
            )

        # Copy so the provider cannot mutate the module-level metric lists.
        metrics = list(_REPORT_METRICS.get(report_type, []))

        if self._config.allowed_metrics is not None:
            blocked = [m for m in metrics if m not in self._config.allowed_metrics]
            if blocked:
                return self._report_failure(
                    f"Report type {report_type!r} requires metrics "
                    f"not in the allowed list: {blocked}",
                    error="metrics_not_allowed",
                    report_type=report_type,
                    blocked_metrics=blocked,
                )

        logger.info(
            ANALYTICS_TOOL_REPORT_START,
            report_type=report_type,
            period=period,
            output_format=output_format,
        )

        try:
            data = await asyncio.wait_for(
                self._provider.query(
                    metrics=metrics,
                    period=period,
                ),
                timeout=self._config.query_timeout,
            )
        except TimeoutError:
            return self._report_failure(
                f"Report query timed out after {self._config.query_timeout}s",
                error="query_timeout",
                timeout=self._config.query_timeout,
                report_type=report_type,
            )
        except (MemoryError, RecursionError):
            raise
        except Exception:
            # Keep provider internals out of the tool output; the
            # traceback goes to the log instead.
            logger.warning(
                ANALYTICS_TOOL_REPORT_FAILED,
                error="provider_error",
                report_type=report_type,
                period=period,
                exc_info=True,
            )
            return ToolExecutionResult(
                content="Report generation failed.",
                is_error=True,
            )

        try:
            report = self._format_report(report_type, period, data, output_format)
        except (MemoryError, RecursionError):
            raise
        except Exception:
            logger.warning(
                ANALYTICS_TOOL_REPORT_FAILED,
                error="format_error",
                report_type=report_type,
                output_format=output_format,
                exc_info=True,
            )
            return ToolExecutionResult(
                content="Report formatting failed.",
                is_error=True,
            )

        logger.info(
            ANALYTICS_TOOL_REPORT_SUCCESS,
            report_type=report_type,
            output_length=len(report),
        )

        return ToolExecutionResult(
            content=report,
            metadata={
                "report_type": report_type,
                "period": period,
                "format": output_format,
            },
        )

    @staticmethod
    def _format_report(
        report_type: str,
        period: str,
        data: dict[str, Any],
        output_format: str,
    ) -> str:
        """Format report data into the requested output format.

        Args:
            report_type: Type of report.
            period: Reporting period.
            data: Raw data from the analytics provider.
            output_format: ``"json"`` or ``"markdown"``; any other
                value is rendered as plain text.

        Returns:
            Formatted report string.
        """
        if output_format == "json":
            return json.dumps(
                {
                    "report_type": report_type,
                    "period": period,
                    "data": data,
                },
                indent=2,
                # Values json cannot encode natively are stringified.
                default=str,
            )

        title = report_type.replace("_", " ").title()

        if output_format == "markdown":
            lines = [f"# {title} Report", "", f"**Period:** {period}", ""]
            lines.extend(f"- **{key}:** {value}" for key, value in sorted(data.items()))
            return "\n".join(lines)

        # Plain text
        lines = [f"{title} Report", f"Period: {period}", ""]
        lines.extend(f"  {key}: {value}" for key, value in sorted(data.items()))
        return "\n".join(lines)
class BaseCommunicationTool(BaseTool, ABC):
    """Common abstract parent of the communication tools.

    Pins the tool category to ``ToolCategory.COMMUNICATION`` and keeps
    the ``CommunicationToolsConfig`` that subclasses read their
    settings from.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str = "",
        parameters_schema: dict[str, Any] | None = None,
        action_type: str | None = None,
        config: CommunicationToolsConfig | None = None,
    ) -> None:
        """Set up the base tool and remember the configuration.

        Args:
            name: Tool name.
            description: Human-readable description.
            parameters_schema: JSON Schema for tool parameters.
            action_type: Security action type override.
            config: Communication tool configuration; a default
                ``CommunicationToolsConfig()`` is built when omitted.
        """
        base_kwargs: dict[str, Any] = {
            "name": name,
            "description": description,
            "category": ToolCategory.COMMUNICATION,
            "parameters_schema": parameters_schema,
            "action_type": action_type,
        }
        super().__init__(**base_kwargs)

        if not config:
            config = CommunicationToolsConfig()
        self._config = config

    @property
    def config(self) -> CommunicationToolsConfig:
        """Return the configuration this tool was created with."""
        return self._config
class CommunicationToolsConfig(BaseModel):
    """Top-level configuration for communication tools.

    Declared as a frozen pydantic model, so the settings are fixed
    once the instance has been constructed.

    Attributes:
        email: SMTP email configuration.  ``None`` (the default)
            disables the email sender tool.
        max_recipients: Maximum number of recipients per email
            (``1``-``1000``, default ``100``).
    """

    # frozen=True makes instances read-only; allow_inf_nan=False rejects
    # infinite/NaN float inputs during validation.
    model_config = ConfigDict(frozen=True, allow_inf_nan=False)

    # Optional: the email sender reports "not configured" while this is None.
    email: EmailConfig | None = Field(
        default=None,
        description="SMTP email config (None = email tool disabled)",
    )
    # Upper bound on recipients for a single email; constrained to 1..1000.
    max_recipients: int = Field(
        default=100,
        gt=0,
        le=1000,
        description="Maximum recipients per email",
    )
class EmailSenderTool(BaseCommunicationTool):
    """Send emails via SMTP.

    Requires ``EmailConfig`` in the communication tools config.
    Uses stdlib ``smtplib`` with ``asyncio.to_thread`` for
    non-blocking execution.  All arguments are validated before any
    SMTP connection is attempted.

    Examples:
        Send a plain text email::

            tool = EmailSenderTool(config=comm_config)
            result = await tool.execute(
                arguments={
                    "to": ["user@example.com"],
                    "subject": "Hello",
                    "body": "World",
                }
            )
    """

    def __init__(
        self,
        *,
        config: CommunicationToolsConfig | None = None,
    ) -> None:
        """Initialize the email sender tool.

        Args:
            config: Communication tool configuration with email
                settings.
        """
        super().__init__(
            name="email_sender",
            description=(
                "Send emails via SMTP. Supports plain text and HTML body content."
            ),
            # Deep copy so later schema edits never touch the module constant.
            parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA),
            action_type=ActionType.COMMS_EXTERNAL,
            config=config,
        )

    @staticmethod
    def _validation_failure(content: str, **log_fields: Any) -> ToolExecutionResult:
        """Log ``COMM_TOOL_EMAIL_VALIDATION_FAILED`` and build the error result."""
        logger.warning(COMM_TOOL_EMAIL_VALIDATION_FAILED, **log_fields)
        return ToolExecutionResult(content=content, is_error=True)

    @staticmethod
    def _is_address_list(candidate: object) -> bool:
        """Return True when *candidate* is a list of non-blank strings."""
        return isinstance(candidate, list) and all(
            isinstance(addr, str) and addr.strip() for addr in candidate
        )

    def _check_recipients(
        self,
        all_recipients: list[str],
        from_address: str,
    ) -> ToolExecutionResult | None:
        """Apply recipient-count and header-safety rules.

        Returns:
            An error result for the first violated rule, or ``None``
            when the recipients are acceptable.
        """
        if not all_recipients:
            return self._validation_failure(
                "At least one recipient is required.",
                reason="no_recipients",
            )

        limit = self._config.max_recipients
        if len(all_recipients) > limit:
            return self._validation_failure(
                f"Too many recipients: {len(all_recipients)} (max {limit})",
                reason="too_many_recipients",
                count=len(all_recipients),
                limit=limit,
            )

        # Reject addresses with newlines (header injection prevention).
        if any(
            _UNSAFE_ADDR_RE.search(addr) for addr in [*all_recipients, from_address]
        ):
            return self._validation_failure(
                "Email address contains invalid characters.",
                reason="unsafe_address",
            )
        return None

    async def execute(  # noqa: PLR0911
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Send an email.

        Args:
            arguments: Must contain ``to`` and ``subject``;
                optionally ``cc``, ``bcc``, ``body``,
                ``body_is_html``.

        Returns:
            A ``ToolExecutionResult`` with delivery status, or an error
            result when email is not configured, an argument is invalid,
            or the SMTP exchange fails.

        Raises:
            MemoryError: Re-raised unchanged from the send thread.
            RecursionError: Re-raised unchanged from the send thread.
        """
        email_config = self._config.email
        if email_config is None:
            logger.warning(
                COMM_TOOL_EMAIL_SEND_FAILED,
                error="email_not_configured",
            )
            return ToolExecutionResult(
                content=(
                    "Email sending requires SMTP configuration. "
                    "Set 'email' in CommunicationToolsConfig."
                ),
                is_error=True,
            )

        # 'to' is mandatory; missing or null 'cc'/'bcc' mean "none".
        to_addrs = arguments.get("to")
        cc_addrs = arguments.get("cc") or []
        bcc_addrs = arguments.get("bcc") or []
        for field_name, addrs in (
            ("to", to_addrs),
            ("cc", cc_addrs),
            ("bcc", bcc_addrs),
        ):
            if not self._is_address_list(addrs):
                return self._validation_failure(
                    f"'{field_name}' must be a list of email addresses.",
                    reason=f"invalid_{field_name}",
                )

        subject = arguments.get("subject")
        if not isinstance(subject, str):
            return self._validation_failure(
                "'subject' must be a string.",
                reason="invalid_subject",
            )
        body = arguments.get("body", "")
        if not isinstance(body, str):
            return self._validation_failure(
                "'body' must be a string.",
                reason="invalid_body",
            )
        # Require a real bool: a string such as "false" would be truthy.
        body_is_html = arguments.get("body_is_html", False)
        if not isinstance(body_is_html, bool):
            return self._validation_failure(
                "'body_is_html' must be a boolean.",
                reason="invalid_body_is_html",
            )

        all_recipients = [*to_addrs, *cc_addrs, *bcc_addrs]
        rejection = self._check_recipients(all_recipients, email_config.from_address)
        if rejection is not None:
            return rejection

        logger.info(
            COMM_TOOL_EMAIL_SEND_START,
            to_count=len(to_addrs),
            cc_count=len(cc_addrs),
            bcc_count=len(bcc_addrs),
            subject_length=len(subject),
        )

        try:
            # smtplib is blocking, so the exchange runs in a worker thread.
            await asyncio.to_thread(
                self._send_sync,
                email_config=email_config,
                to_addrs=to_addrs,
                cc_addrs=cc_addrs,
                all_recipients=all_recipients,
                subject=subject,
                body=body,
                body_is_html=body_is_html,
            )
        except (MemoryError, RecursionError):
            raise
        except Exception:
            logger.warning(
                COMM_TOOL_EMAIL_SEND_FAILED,
                error="smtp_error",
                recipient_count=len(all_recipients),
                exc_info=True,
            )
            return ToolExecutionResult(
                content="Email sending failed.",
                is_error=True,
            )

        logger.info(
            COMM_TOOL_EMAIL_SEND_SUCCESS,
            recipient_count=len(all_recipients),
        )

        return ToolExecutionResult(
            content=(f"Email sent successfully to {len(all_recipients)} recipient(s)."),
            # BCC addresses are reported only as a count.
            metadata={
                "to": to_addrs,
                "cc": cc_addrs,
                "bcc_count": len(bcc_addrs),
                "subject": subject,
            },
        )

    @staticmethod
    def _send_sync(  # noqa: PLR0913
        *,
        email_config: EmailConfig,
        to_addrs: list[str],
        cc_addrs: list[str],
        all_recipients: list[str],
        subject: str,
        body: str,
        body_is_html: bool,
    ) -> None:
        """Synchronous SMTP send (runs in a thread).

        Args:
            email_config: EmailConfig with SMTP settings.
            to_addrs: Primary recipients (``To`` header).
            cc_addrs: CC recipients (``Cc`` header, omitted when empty).
            all_recipients: Combined recipient list for the envelope;
                BCC addresses appear only here, never in a header.
            subject: Email subject; control characters are stripped.
            body: Email body.
            body_is_html: Whether body is HTML.
        """
        safe_subject = _CONTROL_CHAR_RE.sub("", subject)
        msg = EmailMessage()
        msg["Subject"] = safe_subject
        msg["From"] = email_config.from_address
        msg["To"] = ", ".join(to_addrs)
        if cc_addrs:
            msg["Cc"] = ", ".join(cc_addrs)

        if body_is_html:
            msg.set_content(body, subtype="html")
        else:
            msg.set_content(body)

        timeout = email_config.smtp_timeout
        context = ssl.create_default_context()
        smtp_conn: smtplib.SMTP
        if email_config.use_implicit_tls:
            smtp_conn = smtplib.SMTP_SSL(
                email_config.host,
                email_config.port,
                timeout=timeout,
                context=context,
            )
        else:
            smtp_conn = smtplib.SMTP(
                email_config.host, email_config.port, timeout=timeout
            )
        with smtp_conn as smtp:
            # STARTTLS applies only to a plain connection; SMTP_SSL is
            # already encrypted.
            if not email_config.use_implicit_tls and email_config.use_tls:
                smtp.starttls(context=context)
            if email_config.username and email_config.password:
                smtp.login(email_config.username, email_config.password)
            smtp.send_message(msg, to_addrs=all_recipients)
@runtime_checkable
class NotificationDispatcherProtocol(Protocol):
    """Protocol for notification dispatch -- matches ``NotificationDispatcher``.

    Structural interface for any object exposing an async
    ``dispatch(notification)`` method.  Because the protocol is
    ``runtime_checkable``, ``isinstance`` checks verify only that a
    ``dispatch`` attribute exists, not its signature.
    """

    async def dispatch(self, notification: Notification) -> None:
        """Dispatch a notification to all registered sinks.

        Args:
            notification: The notification to deliver.
        """
        ...
class NotificationSenderTool(BaseCommunicationTool):
    """Send notifications via the existing notification subsystem.

    Delegates to the ``NotificationDispatcher`` which fans out
    to all registered sinks (console, ntfy, Slack, email).

    Examples:
        Send a notification::

            tool = NotificationSenderTool(dispatcher=my_dispatcher)
            result = await tool.execute(
                arguments={
                    "category": "system",
                    "severity": "info",
                    "title": "Deployment complete",
                    "source": "deploy-agent",
                }
            )
    """

    def __init__(
        self,
        *,
        dispatcher: NotificationDispatcherProtocol | None = None,
        config: CommunicationToolsConfig | None = None,
    ) -> None:
        """Initialize the notification sender tool.

        Args:
            dispatcher: Notification dispatcher instance.
                ``None`` means the tool will return an error.
            config: Communication tool configuration.
        """
        super().__init__(
            name="notification_sender",
            description=(
                "Send notifications to registered sinks (console, email, Slack, ntfy)."
            ),
            # Deep copy so later schema edits never touch the module constant.
            parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA),
            action_type=ActionType.COMMS_INTERNAL,
            config=config,
        )
        self._dispatcher = dispatcher

    async def execute(  # noqa: PLR0911
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Send a notification.

        Args:
            arguments: Must contain ``category``, ``severity``,
                ``title``, and ``source``; optionally ``body``.

        Returns:
            A ``ToolExecutionResult`` with dispatch status, or an error
            result when no dispatcher is configured, a field is missing
            or invalid, or the dispatcher raises.

        Raises:
            MemoryError: Re-raised unchanged if the dispatcher raises it.
            RecursionError: Re-raised unchanged if the dispatcher raises it.
        """
        if self._dispatcher is None:
            logger.warning(
                COMM_TOOL_NOTIFICATION_SEND_FAILED,
                error="dispatcher_not_configured",
            )
            return ToolExecutionResult(
                content=(
                    "Notification sending requires a configured "
                    "NotificationDispatcher. None was provided."
                ),
                is_error=True,
            )

        body: str = arguments.get("body", "")

        # Collect the mandatory string fields; the first one that is
        # missing, empty, or not a string aborts the call.
        fields: dict[str, str] = {}
        for field_name in ("category", "severity", "title", "source"):
            field_val = arguments.get(field_name)
            if not isinstance(field_val, str) or not field_val:
                logger.warning(
                    COMM_TOOL_NOTIFICATION_SEND_FAILED,
                    error="missing_field",
                    field=field_name,
                )
                return ToolExecutionResult(
                    content=(
                        f"'{field_name}' is required and must be a non-empty string."
                    ),
                    is_error=True,
                )
            fields[field_name] = field_val

        category_str = fields["category"]
        severity_str = fields["severity"]
        title = fields["title"]
        source = fields["source"]

        if category_str not in _VALID_CATEGORIES:
            logger.warning(
                COMM_TOOL_NOTIFICATION_SEND_FAILED,
                error="invalid_category",
                category=category_str,
            )
            return ToolExecutionResult(
                content=(
                    f"Invalid category: {category_str!r}. "
                    f"Must be one of: {sorted(_VALID_CATEGORIES)}"
                ),
                is_error=True,
            )

        if severity_str not in _VALID_SEVERITIES:
            logger.warning(
                COMM_TOOL_NOTIFICATION_SEND_FAILED,
                error="invalid_severity",
                severity=severity_str,
            )
            return ToolExecutionResult(
                content=(
                    f"Invalid severity: {severity_str!r}. "
                    f"Must be one of: {sorted(_VALID_SEVERITIES)}"
                ),
                is_error=True,
            )

        try:
            notification = Notification(
                category=NotificationCategory(category_str),
                severity=NotificationSeverity(severity_str),
                title=title,
                body=body,
                source=source,
                timestamp=datetime.now(UTC),
            )
        except (ValueError, TypeError, ValidationError) as exc:
            # Model validation errors describe the caller's own input,
            # so they are reported back in the tool output.
            logger.warning(
                COMM_TOOL_NOTIFICATION_SEND_FAILED,
                error="invalid_notification_fields",
                detail=str(exc),
            )
            return ToolExecutionResult(
                content=f"Invalid notification fields: {exc}",
                is_error=True,
            )

        logger.info(
            COMM_TOOL_NOTIFICATION_SEND_START,
            notification_id=notification.id,
            category=category_str,
            severity=severity_str,
        )

        try:
            await self._dispatcher.dispatch(notification)
        except (MemoryError, RecursionError):
            raise
        except Exception:
            # Keep dispatcher internals out of the tool output; the
            # traceback goes to the log instead.
            logger.warning(
                COMM_TOOL_NOTIFICATION_SEND_FAILED,
                notification_id=notification.id,
                error="dispatch_error",
                exc_info=True,
            )
            return ToolExecutionResult(
                content="Notification dispatch failed.",
                is_error=True,
            )

        logger.info(
            COMM_TOOL_NOTIFICATION_SEND_SUCCESS,
            notification_id=notification.id,
        )

        return ToolExecutionResult(
            content=(f"Notification dispatched: [{severity_str}] {title}"),
            metadata={
                "notification_id": notification.id,
                "category": category_str,
                "severity": severity_str,
            },
        )
class TemplateFormatterTool(BaseCommunicationTool):
    """Format message templates with safe variable substitution.

    Uses Jinja2 ``SandboxedEnvironment`` to prevent arbitrary
    code execution. Only inline templates are supported (no
    file-based templates) to avoid path traversal risks.

    Two environments are kept: a plain one for ``text`` and
    ``markdown`` output and an auto-escaping one for ``html`` output.

    Examples:
        Render a template::

            tool = TemplateFormatterTool()
            result = await tool.execute(
                arguments={
                    "template": "Hello {{ name }}, your balance is {{ amount }}.",
                    "variables": {"name": "Alice", "amount": "$100"},
                }
            )
    """

    def __init__(
        self,
        *,
        config: CommunicationToolsConfig | None = None,
    ) -> None:
        """Initialize the template formatter tool.

        Args:
            config: Communication tool configuration.
        """
        super().__init__(
            name="template_formatter",
            description=(
                "Render inline message templates with safe "
                "Jinja2 variable substitution."
            ),
            parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA),
            action_type=ActionType.CODE_READ,
            config=config,
        )
        self._env = SandboxedEnvironment()
        self._env_autoesc = SandboxedEnvironment(autoescape=True)

    async def execute(  # noqa: PLR0911
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Render a template with variables.

        Invalid arguments, template syntax errors, and rendering
        errors are logged and returned as error results. Only
        ``MemoryError`` and ``RecursionError`` raised while rendering
        propagate to the caller.

        Args:
            arguments: Must contain ``template`` (str) and
                ``variables`` (dict); optionally ``format``, one of
                ``text``, ``html``, or ``markdown`` (default ``text``).

        Returns:
            A ``ToolExecutionResult`` with the rendered text and, in
            its metadata, the output format and rendered length.
        """
        template_str = arguments.get("template")
        variables = arguments.get("variables")
        if not isinstance(template_str, str):
            logger.warning(
                COMM_TOOL_TEMPLATE_RENDER_FAILED,
                error="missing_or_invalid_template",
            )
            return ToolExecutionResult(
                content="'template' must be a string.",
                is_error=True,
            )
        if not isinstance(variables, dict):
            logger.warning(
                COMM_TOOL_TEMPLATE_RENDER_FAILED,
                error="missing_or_invalid_variables",
            )
            return ToolExecutionResult(
                content="'variables' must be a dict.",
                is_error=True,
            )
        output_format = arguments.get("format", "text")
        if not isinstance(output_format, str):
            logger.warning(
                COMM_TOOL_TEMPLATE_RENDER_FAILED,
                error="invalid_format_type",
            )
            return ToolExecutionResult(
                content="'format' must be a string.",
                is_error=True,
            )

        if output_format not in _OUTPUT_FORMATS:
            logger.warning(
                COMM_TOOL_TEMPLATE_RENDER_FAILED,
                error="invalid_output_format",
                output_format=output_format,
            )
            return ToolExecutionResult(
                content=(
                    f"Invalid format: {output_format!r}. "
                    f"Must be one of: {sorted(_OUTPUT_FORMATS)}"
                ),
                is_error=True,
            )

        logger.info(
            COMM_TOOL_TEMPLATE_RENDER_START,
            template_length=len(template_str),
            variable_count=len(variables),
            output_format=output_format,
        )

        # Only HTML output is auto-escaped; text and markdown are
        # rendered verbatim.
        env = self._env_autoesc if output_format == "html" else self._env
        try:
            tmpl = env.from_string(template_str)
        except TemplateSyntaxError as exc:
            logger.warning(
                COMM_TOOL_TEMPLATE_RENDER_INVALID,
                error=str(exc),
            )
            return ToolExecutionResult(
                content=f"Invalid template syntax: {exc}",
                is_error=True,
            )

        try:
            # Pass the mapping positionally rather than splatting it:
            # ``render(**variables)`` raises ``TypeError`` for a
            # user-supplied key named ``self`` (it collides with the
            # bound method parameter). ``Template.render`` accepts the
            # same arguments as the ``dict`` constructor.
            rendered = tmpl.render(variables)
        except (MemoryError, RecursionError):
            raise
        except Exception as exc:
            logger.warning(
                COMM_TOOL_TEMPLATE_RENDER_FAILED,
                error=str(exc),
            )
            return ToolExecutionResult(
                content=f"Template rendering failed: {exc}",
                is_error=True,
            )

        logger.info(
            COMM_TOOL_TEMPLATE_RENDER_SUCCESS,
            output_length=len(rendered),
            output_format=output_format,
        )

        return ToolExecutionResult(
            content=rendered,
            metadata={
                "format": output_format,
                "output_length": len(rendered),
            },
        )
class AssetManagerTool(BaseDesignTool):
    """Query and prune the registry of generated design assets.

    The registry is a plain in-memory mapping from asset id to a
    metadata dict. Other design tools add entries through
    ``register_asset``; this tool exposes ``list``, ``get``,
    ``delete``, and ``search`` on top of it.

    Examples:
        List all assets::

            tool = AssetManagerTool()
            result = await tool.execute(arguments={"action": "list"})

        Get a specific asset::

            result = await tool.execute(
                arguments={"action": "get", "asset_id": "img-001"}
            )
    """

    def __init__(
        self,
        *,
        config: DesignToolsConfig | None = None,
        assets: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Create the tool, optionally seeded with existing assets.

        Args:
            config: Design tool configuration.
            assets: Initial registry contents; deep-copied so later
                changes by the caller do not leak in. ``None`` (or an
                empty mapping) starts with an empty registry.
        """
        super().__init__(
            name="asset_manager",
            description=("List, retrieve, delete, and search generated design assets."),
            parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA),
            action_type=ActionType.DOCS_WRITE,
            config=config,
        )
        self._assets: dict[str, dict[str, Any]] = {}
        if assets:
            self._assets = copy.deepcopy(assets)

    @staticmethod
    def _string_tags(raw: Any) -> list[str]:
        """Return the string items of *raw* when it is a list, else ``[]``."""
        if not isinstance(raw, list):
            return []
        return [item for item in raw if isinstance(item, str)]

    @staticmethod
    def _carries_tags(meta: dict[str, Any], wanted: set[str]) -> bool:
        """Tell whether the asset's string tags include every wanted tag."""
        own_tags = {t for t in (meta.get("tags") or []) if isinstance(t, str)}
        return wanted.issubset(own_tags)

    def register_asset(
        self,
        asset_id: str,
        metadata: dict[str, Any],
    ) -> None:
        """Store (or overwrite) an asset's metadata in the registry.

        Intended to be called by other tools once they have produced
        an asset. The metadata is deep-copied before it is stored.

        Args:
            asset_id: Unique asset identifier.
            metadata: Asset metadata (type, dimensions, tags, etc.).

        Raises:
            ValueError: If asset_id is empty or whitespace-only.
        """
        if asset_id.strip() == "":
            msg = "asset_id must not be empty"
            raise ValueError(msg)
        self._assets[asset_id] = copy.deepcopy(metadata)
        logger.info(
            DESIGN_ASSET_STORED,
            asset_id=asset_id,
            asset_type=metadata.get("type", "unknown"),
        )

    async def execute(
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Route an asset operation to its handler.

        Args:
            arguments: Must contain ``action``; optionally
                ``asset_id``, ``tags``, ``query``.

        Returns:
            A ``ToolExecutionResult`` with operation results, or an
            error result when ``action`` is missing or unknown.
        """
        action = arguments.get("action")
        if not isinstance(action, str):
            logger.warning(
                DESIGN_ASSET_VALIDATION_FAILED,
                reason="missing_action",
            )
            return ToolExecutionResult(
                content="'action' is required and must be a string.",
                is_error=True,
            )

        if action not in _VALID_ACTIONS:
            logger.warning(
                DESIGN_ASSET_VALIDATION_FAILED,
                action=action,
                reason="invalid_action",
            )
            return ToolExecutionResult(
                content=(
                    f"Invalid action: {action!r}. "
                    f"Must be one of: {sorted(_VALID_ACTIONS)}"
                ),
                is_error=True,
            )

        handlers = {
            "list": self._handle_list,
            "get": self._handle_get,
            "delete": self._handle_delete,
            "search": self._handle_search,
        }
        return handlers[action](arguments)

    def _handle_list(
        self,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """List assets, optionally narrowed to those carrying all given tags."""
        raw_tags = arguments.get("tags")
        if not (raw_tags is None or isinstance(raw_tags, list)):
            # A malformed filter is noted but otherwise ignored.
            logger.debug(
                DESIGN_ASSET_VALIDATION_FAILED,
                action="list",
                reason="invalid_tags_type",
            )
        tags = self._string_tags(raw_tags)
        wanted = set(tags)

        matching = self._assets
        if wanted:
            matching = {
                aid: meta
                for aid, meta in self._assets.items()
                if self._carries_tags(meta, wanted)
            }

        logger.info(
            DESIGN_ASSET_LISTED,
            total=len(self._assets),
            matched=len(matching),
            filter_tags=tags,
        )

        if not matching:
            return ToolExecutionResult(content="No assets found.")

        report = [f"Found {len(matching)} asset(s):"]
        for aid in sorted(matching):
            meta = matching[aid]
            asset_type = meta.get("type", "unknown")
            asset_tags = meta.get("tags", [])
            report.append(f" - {aid}: type={asset_type}, tags={asset_tags}")
        return ToolExecutionResult(
            content="\n".join(report),
            metadata={"count": len(matching)},
        )

    def _handle_get(
        self,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Return one asset's metadata, looked up by ``asset_id``."""
        asset_id = arguments.get("asset_id")
        if not (isinstance(asset_id, str) and asset_id.strip()):
            logger.warning(
                DESIGN_ASSET_VALIDATION_FAILED,
                action="get",
                reason="missing_asset_id",
            )
            return ToolExecutionResult(
                content="asset_id is required for 'get' action.",
                is_error=True,
            )

        meta = self._assets.get(asset_id)
        if meta is None:
            logger.warning(
                DESIGN_ASSET_VALIDATION_FAILED,
                action="get",
                reason="not_found",
                asset_id=asset_id,
            )
            return ToolExecutionResult(
                content=f"Asset not found: {asset_id!r}",
                is_error=True,
            )

        logger.info(
            DESIGN_ASSET_RETRIEVED,
            asset_id=asset_id,
        )

        report = [f"Asset: {asset_id}"]
        report.extend(f" {key}: {value}" for key, value in sorted(meta.items()))
        # Hand out a copy so callers cannot mutate the registry entry.
        return ToolExecutionResult(
            content="\n".join(report),
            metadata=copy.deepcopy(meta),
        )

    def _handle_delete(
        self,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Remove one asset from the registry, looked up by ``asset_id``."""
        asset_id = arguments.get("asset_id")
        if not (isinstance(asset_id, str) and asset_id.strip()):
            logger.warning(
                DESIGN_ASSET_VALIDATION_FAILED,
                action="delete",
                reason="missing_asset_id",
            )
            return ToolExecutionResult(
                content="asset_id is required for 'delete' action.",
                is_error=True,
            )

        if asset_id not in self._assets:
            logger.warning(
                DESIGN_ASSET_VALIDATION_FAILED,
                action="delete",
                reason="not_found",
                asset_id=asset_id,
            )
            return ToolExecutionResult(
                content=f"Asset not found: {asset_id!r}",
                is_error=True,
            )

        del self._assets[asset_id]

        logger.info(
            DESIGN_ASSET_DELETED,
            asset_id=asset_id,
        )

        return ToolExecutionResult(
            content=f"Asset deleted: {asset_id}",
        )

    def _handle_search(
        self,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Find assets whose metadata values contain the query text.

        Matching is a case-insensitive substring test against the
        space-joined string forms of the metadata values; an optional
        ``tags`` list narrows the hits further.
        """
        raw_query = arguments.get("query")
        if not (isinstance(raw_query, str) and raw_query.strip()):
            logger.warning(
                DESIGN_ASSET_VALIDATION_FAILED,
                action="search",
                reason="missing_query",
            )
            return ToolExecutionResult(
                content="query is required for 'search' action.",
                is_error=True,
            )

        query = raw_query.strip().lower()
        tags = self._string_tags(arguments.get("tags"))
        wanted = set(tags)

        matching: dict[str, dict[str, Any]] = {}
        for aid, meta in self._assets.items():
            haystack = " ".join(str(value).lower() for value in meta.values())
            if query in haystack and (
                not wanted or self._carries_tags(meta, wanted)
            ):
                matching[aid] = meta

        logger.info(
            DESIGN_ASSET_SEARCHED,
            total=len(self._assets),
            matched=len(matching),
            search_query=query,
            filter_tags=tags,
        )

        if not matching:
            return ToolExecutionResult(
                content=f"No assets matching query: {query!r}",
            )

        report = [f"Found {len(matching)} asset(s) matching {query!r}:"]
        for aid in sorted(matching):
            asset_type = matching[aid].get("type", "unknown")
            report.append(f" - {aid}: type={asset_type}")
        return ToolExecutionResult(
            content="\n".join(report),
            metadata={"count": len(matching)},
        )
class BaseDesignTool(BaseTool, ABC):
    """Common parent of the design tools.

    Pins the tool category to ``ToolCategory.DESIGN`` and stores the
    ``DesignToolsConfig`` shared by concrete design tools.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str = "",
        parameters_schema: dict[str, Any] | None = None,
        action_type: str | None = None,
        config: DesignToolsConfig | None = None,
    ) -> None:
        """Set up the base tool and remember the design configuration.

        Args:
            name: Tool name.
            description: Human-readable description.
            parameters_schema: JSON Schema for tool parameters.
            action_type: Security action type override.
            config: Design tool configuration; a default
                ``DesignToolsConfig()`` is created when omitted.
        """
        base_kwargs: dict[str, Any] = {
            "name": name,
            "description": description,
            "category": ToolCategory.DESIGN,
            "parameters_schema": parameters_schema,
            "action_type": action_type,
        }
        super().__init__(**base_kwargs)
        self._config = config if config else DesignToolsConfig()

    @property
    def config(self) -> DesignToolsConfig:
        """The design tool configuration."""
        return self._config
+ """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + image_timeout: float = Field( + default=60.0, + gt=0, + le=600.0, + description="Image generation timeout (seconds)", + ) + max_image_size_bytes: int = Field( + default=52_428_800, + gt=0, + description="Maximum image output size (bytes, default 50 MB)", + ) + asset_storage_path: NotBlankStr | None = Field( + default=None, + description="Filesystem path for asset storage (None = in-memory)", + ) diff --git a/src/synthorg/tools/design/diagram_generator.py b/src/synthorg/tools/design/diagram_generator.py new file mode 100644 index 0000000000..faa42ecf49 --- /dev/null +++ b/src/synthorg/tools/design/diagram_generator.py @@ -0,0 +1,273 @@ +"""Diagram generator tool -- generate Mermaid/Graphviz DSL from descriptions. + +Produces diagram markup (Mermaid or Graphviz DOT) that can be rendered +by downstream tools or the web dashboard. No external provider is +required -- the tool outputs DSL text directly. +""" + +import copy +from typing import Any, Final + +from synthorg.core.enums import ActionType +from synthorg.observability import get_logger +from synthorg.observability.events.design import ( + DESIGN_DIAGRAM_GENERATION_FAILED, + DESIGN_DIAGRAM_GENERATION_START, + DESIGN_DIAGRAM_GENERATION_SUCCESS, +) +from synthorg.tools.base import ToolExecutionResult +from synthorg.tools.design.base_design_tool import BaseDesignTool +from synthorg.tools.design.config import DesignToolsConfig # noqa: TC001 + +logger = get_logger(__name__) + +_DIAGRAM_TYPES: Final[frozenset[str]] = frozenset( + { + "flowchart", + "sequence", + "class", + "state", + "architecture", + } +) + +_OUTPUT_FORMATS: Final[frozenset[str]] = frozenset( + { + "mermaid", + "graphviz", + } +) + +_PARAMETERS_SCHEMA: Final[dict[str, Any]] = { + "type": "object", + "properties": { + "diagram_type": { + "type": "string", + "enum": sorted(_DIAGRAM_TYPES), + "description": "Type of diagram to generate", + }, + "description": { + "type": "string", 
class DiagramGeneratorTool(BaseDesignTool):
    """Generate diagram markup (Mermaid/Graphviz) from structured descriptions.

    Produces DSL text that can be rendered by Mermaid.js, Graphviz,
    or the web dashboard. No external API is needed: the tool only
    wraps the caller's description in the matching diagram header.

    Examples:
        Generate a flowchart::

            tool = DiagramGeneratorTool()
            result = await tool.execute(
                arguments={
                    "diagram_type": "flowchart",
                    "description": "A -> B -> C",
                    "title": "Simple Flow",
                }
            )
    """

    def __init__(
        self,
        *,
        config: DesignToolsConfig | None = None,
    ) -> None:
        """Initialize the diagram generator tool.

        Args:
            config: Design tool configuration.
        """
        super().__init__(
            name="diagram_generator",
            description=(
                "Generate diagram markup (Mermaid or Graphviz) "
                "from structured descriptions."
            ),
            parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA),
            action_type=ActionType.DOCS_WRITE,
            config=config,
        )

    async def execute(  # noqa: PLR0911
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Generate diagram markup from a description.

        Missing or wrongly typed arguments are reported as error
        results (logged under ``DESIGN_DIAGRAM_GENERATION_FAILED``)
        instead of surfacing as ``KeyError``/``TypeError``.

        Args:
            arguments: Must contain ``diagram_type`` and
                ``description`` (both strings); optionally ``title``
                (string or ``None``) and ``output_format``
                (default ``mermaid``).

        Returns:
            A ``ToolExecutionResult`` with the diagram DSL, or an
            error result when validation or generation fails.
        """
        diagram_type = arguments.get("diagram_type")
        description = arguments.get("description")
        raw_title = arguments.get("title")
        output_format = arguments.get("output_format", "mermaid")

        # Type-check up front: the membership tests below need hashable
        # values and the generators call ``str`` methods on these.
        for field_name, field_val, nullable in (
            ("diagram_type", diagram_type, False),
            ("description", description, False),
            ("title", raw_title, True),
            ("output_format", output_format, False),
        ):
            if isinstance(field_val, str) or (nullable and field_val is None):
                continue
            logger.warning(
                DESIGN_DIAGRAM_GENERATION_FAILED,
                error="invalid_argument_type",
                field=field_name,
            )
            return ToolExecutionResult(
                content=f"'{field_name}' must be a string.",
                is_error=True,
            )

        # An absent or null title means "no title".
        title: str = raw_title or ""

        if diagram_type not in _DIAGRAM_TYPES:
            logger.warning(
                DESIGN_DIAGRAM_GENERATION_FAILED,
                error="invalid_diagram_type",
                diagram_type=diagram_type,
            )
            return ToolExecutionResult(
                content=(
                    f"Invalid diagram_type: {diagram_type!r}. "
                    f"Must be one of: {sorted(_DIAGRAM_TYPES)}"
                ),
                is_error=True,
            )

        if output_format not in _OUTPUT_FORMATS:
            logger.warning(
                DESIGN_DIAGRAM_GENERATION_FAILED,
                error="invalid_output_format",
                output_format=output_format,
            )
            return ToolExecutionResult(
                content=(
                    f"Invalid output_format: {output_format!r}. "
                    f"Must be one of: {sorted(_OUTPUT_FORMATS)}"
                ),
                is_error=True,
            )

        logger.info(
            DESIGN_DIAGRAM_GENERATION_START,
            diagram_type=diagram_type,
            output_format=output_format,
            description_length=len(description),
        )

        try:
            if output_format == "mermaid":
                markup = self._generate_mermaid(diagram_type, description, title)
            else:
                markup = self._generate_graphviz(diagram_type, description, title)
        except (MemoryError, RecursionError):
            raise
        except Exception:
            logger.warning(
                DESIGN_DIAGRAM_GENERATION_FAILED,
                error="internal_error",
                diagram_type=diagram_type,
                exc_info=True,
            )
            return ToolExecutionResult(
                content="Diagram generation failed.",
                is_error=True,
            )

        logger.info(
            DESIGN_DIAGRAM_GENERATION_SUCCESS,
            diagram_type=diagram_type,
            output_format=output_format,
            markup_length=len(markup),
        )

        return ToolExecutionResult(
            content=markup,
            metadata={
                "diagram_type": diagram_type,
                "output_format": output_format,
                "title": title,
            },
        )

    @staticmethod
    def _generate_mermaid(
        diagram_type: str,
        description: str,
        title: str,
    ) -> str:
        """Generate Mermaid DSL from the description.

        Emits an optional front-matter title block, the Mermaid
        directive for the diagram type, and then each line of the
        description prefixed with a space.

        Args:
            diagram_type: Type of diagram.
            description: User-provided diagram specification.
            title: Optional title; backslashes and double quotes are
                escaped and line breaks replaced with spaces.

        Returns:
            Mermaid markup string.
        """
        type_map: dict[str, str] = {
            "flowchart": "flowchart TD",
            "sequence": "sequenceDiagram",
            "class": "classDiagram",
            "state": "stateDiagram-v2",
            "architecture": "flowchart TD",
        }
        directive = type_map.get(diagram_type, "flowchart TD")
        lines: list[str] = []
        if title:
            # Escape the backslash first so the escapes added for
            # quotes are not themselves doubled.
            safe_title = (
                title.replace("\\", "\\\\")
                .replace('"', '\\"')
                .replace("\r", " ")
                .replace("\n", " ")
            )
            lines.append("---")
            lines.append(f'title: "{safe_title}"')
            lines.append("---")
        lines.append(directive)
        lines.extend(f" {line}" for line in description.strip().splitlines())
        return "\n".join(lines)

    @staticmethod
    def _generate_graphviz(
        diagram_type: str,
        description: str,
        title: str,
    ) -> str:
        """Generate Graphviz DOT from the description.

        Wraps the user-provided description in a DOT block named
        after the diagram type: an undirected ``graph`` for
        ``architecture`` and a ``digraph`` for everything else.

        Args:
            diagram_type: Type of diagram (selects the graph keyword
                and is used as the graph name).
            description: User-provided diagram specification.
            title: Optional title, emitted as the graph ``label``.

        Returns:
            Graphviz DOT string.
        """
        graph_type = "graph" if diagram_type == "architecture" else "digraph"
        if title:
            escaped = (
                title.replace("\\", "\\\\")
                .replace('"', '\\"')
                .replace("\r", "")
                .replace("\n", "\\n")
            )
            label = f' label="{escaped}";\n'
        else:
            label = ""
        return f"{graph_type} {diagram_type} {{\n{label} {description}\n}}"
class ImageResult(BaseModel):
    """Result from an image generation provider.

    Instances are frozen (immutable after construction).

    Attributes:
        data: The image content as a non-empty base64-encoded string
            (not raw bytes).
        content_type: MIME type of the generated image; defaults to
            ``image/png``.
        width: Image width in pixels (must be greater than 0).
        height: Image height in pixels (must be greater than 0).
    """

    model_config = ConfigDict(frozen=True, allow_inf_nan=False)

    data: str = Field(min_length=1, description="Base64-encoded image data")
    content_type: str = Field(
        default="image/png",
        description="MIME type of the generated image",
    )
    width: int = Field(gt=0, description="Image width in pixels")
    height: int = Field(gt=0, description="Image height in pixels")
class ImageGeneratorTool(BaseDesignTool):
    """Generate images from text prompts via an abstracted provider.

    Requires an ``ImageProvider`` implementation to be injected at
    construction time. If no provider is configured, the tool
    returns an error explaining the requirement.

    Examples:
        Generate an image::

            tool = ImageGeneratorTool(provider=my_provider)
            result = await tool.execute(arguments={"prompt": "A sunset over mountains"})
    """

    def __init__(
        self,
        *,
        provider: ImageProvider | None = None,
        config: DesignToolsConfig | None = None,
    ) -> None:
        """Initialize the image generator tool.

        Args:
            provider: Image generation provider. ``None`` means
                the tool will return an error on execution.
            config: Design tool configuration.
        """
        super().__init__(
            name="image_generator",
            description=(
                "Generate images from text prompts. Supports style and quality presets."
            ),
            parameters_schema=copy.deepcopy(_PARAMETERS_SCHEMA),
            action_type=ActionType.DOCS_WRITE,
            config=config,
        )
        self._provider = provider

    @staticmethod
    def _is_valid_dimension(value: object) -> bool:
        """Tell whether *value* is a real ``int`` within the allowed range.

        ``bool`` is rejected explicitly because it is a subclass of
        ``int``; floats and strings are rejected by the type check.
        """
        return (
            isinstance(value, int)
            and not isinstance(value, bool)
            and _MIN_DIMENSION <= value <= _MAX_DIMENSION
        )

    async def execute(  # noqa: C901, PLR0911
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Generate an image from a text prompt.

        Argument problems, provider failures, timeouts, undecodable
        provider output, and oversized images are all logged and
        returned as error results. Only ``MemoryError`` and
        ``RecursionError`` raised by the provider propagate.

        Args:
            arguments: Must contain ``prompt`` (non-empty string);
                optionally ``style``, ``width``, ``height``,
                ``quality``.

        Returns:
            A ``ToolExecutionResult`` with image data or error. On
            success the base64 payload, content type, and dimensions
            are placed in the result metadata.
        """
        if self._provider is None:
            logger.warning(
                DESIGN_PROVIDER_NOT_CONFIGURED,
                tool="image_generator",
            )
            return ToolExecutionResult(
                content=(
                    "Image generation requires a configured provider. "
                    "No ImageProvider has been injected."
                ),
                is_error=True,
            )

        prompt = arguments.get("prompt")
        if not isinstance(prompt, str) or not prompt.strip():
            # Report a missing/blank prompt instead of raising KeyError.
            logger.warning(
                DESIGN_IMAGE_GENERATION_FAILED,
                error="missing_prompt",
            )
            return ToolExecutionResult(
                content="'prompt' is required and must be a non-empty string.",
                is_error=True,
            )

        style = arguments.get("style", "realistic")
        width = arguments.get("width", 1024)
        height = arguments.get("height", 1024)
        quality = arguments.get("quality", "standard")

        if not (self._is_valid_dimension(width) and self._is_valid_dimension(height)):
            logger.warning(
                DESIGN_IMAGE_GENERATION_FAILED,
                error="invalid_dimensions",
                width=width,
                height=height,
            )
            return ToolExecutionResult(
                content=(
                    f"Width and height must be between "
                    f"{_MIN_DIMENSION} and {_MAX_DIMENSION}. "
                    f"Got width={width}, height={height}."
                ),
                is_error=True,
            )

        # The isinstance guards keep unhashable values (e.g. a list)
        # from raising TypeError in the frozenset membership tests.
        if not isinstance(style, str) or style not in _VALID_STYLES:
            logger.warning(
                DESIGN_IMAGE_GENERATION_FAILED,
                error="invalid_style",
                style=style,
            )
            return ToolExecutionResult(
                content=(
                    f"Invalid style: {style!r}. Must be one of: {sorted(_VALID_STYLES)}"
                ),
                is_error=True,
            )

        if not isinstance(quality, str) or quality not in _VALID_QUALITIES:
            logger.warning(
                DESIGN_IMAGE_GENERATION_FAILED,
                error="invalid_quality",
                quality=quality,
            )
            return ToolExecutionResult(
                content=(
                    f"Invalid quality: {quality!r}. "
                    f"Must be one of: {sorted(_VALID_QUALITIES)}"
                ),
                is_error=True,
            )

        logger.info(
            DESIGN_IMAGE_GENERATION_START,
            prompt_length=len(prompt),
            style=style,
            width=width,
            height=height,
            quality=quality,
        )

        try:
            result = await asyncio.wait_for(
                self._provider.generate(
                    prompt=prompt,
                    width=width,
                    height=height,
                    style=style,
                    quality=quality,
                ),
                timeout=self._config.image_timeout,
            )
        except TimeoutError:
            logger.warning(
                DESIGN_IMAGE_GENERATION_TIMEOUT,
                timeout=self._config.image_timeout,
                prompt_length=len(prompt),
            )
            return ToolExecutionResult(
                content=(
                    f"Image generation timed out after {self._config.image_timeout}s"
                ),
                is_error=True,
            )
        except (MemoryError, RecursionError):
            raise
        except Exception:
            logger.warning(
                DESIGN_IMAGE_GENERATION_FAILED,
                error="provider_error",
                prompt_length=len(prompt),
                style=style,
                exc_info=True,
            )
            return ToolExecutionResult(
                content="Image generation failed.",
                is_error=True,
            )

        # Decode only to validate the payload and measure its real
        # size; the base64 text itself is what gets returned.
        try:
            decoded_bytes = base64.b64decode(result.data, validate=True)
        except Exception as decode_exc:
            logger.warning(
                DESIGN_IMAGE_GENERATION_FAILED,
                error="invalid_base64",
                detail=str(decode_exc),
            )
            return ToolExecutionResult(
                content=(f"Provider returned invalid base64 image data: {decode_exc}"),
                is_error=True,
            )
        byte_size = len(decoded_bytes)
        if byte_size > self._config.max_image_size_bytes:
            logger.warning(
                DESIGN_IMAGE_GENERATION_FAILED,
                error="image_too_large",
                byte_size=byte_size,
                max_size=self._config.max_image_size_bytes,
            )
            return ToolExecutionResult(
                content=(
                    f"Generated image exceeds size limit: "
                    f"{byte_size} bytes "
                    f"(max {self._config.max_image_size_bytes})"
                ),
                is_error=True,
            )

        logger.info(
            DESIGN_IMAGE_GENERATION_SUCCESS,
            width=result.width,
            height=result.height,
            content_type=result.content_type,
            data_length=len(result.data),
        )

        return ToolExecutionResult(
            content=(
                f"Image generated successfully.\n"
                f"Dimensions: {result.width}x{result.height}\n"
                f"Type: {result.content_type}\n"
                f"Data length: {len(result.data)} chars (base64)"
            ),
            metadata={
                "data": result.data,
                "content_type": result.content_type,
                "width": result.width,
                "height": result.height,
            },
        )
+ + Returns an empty tuple when *config* is ``None``. + """ + if config is None: + return () + from synthorg.tools.design import ( # noqa: PLC0415 + AssetManagerTool, + DiagramGeneratorTool, + ImageGeneratorTool, + ) + + tools: list[BaseTool] = [ + DiagramGeneratorTool(config=config), + AssetManagerTool(config=config), + ] + if image_provider is not None: + tools.append(ImageGeneratorTool(provider=image_provider, config=config)) + return tuple(tools) + + +def _build_communication_tools( + *, + config: CommunicationToolsConfig | None = None, + dispatcher: NotificationDispatcherProtocol | None = None, +) -> tuple[BaseTool, ...]: + """Instantiate the built-in communication tools. + + Returns an empty tuple when *config* is ``None``. + """ + if config is None: + return () + from synthorg.tools.communication import ( # noqa: PLC0415 + EmailSenderTool, + NotificationSenderTool, + TemplateFormatterTool, + ) + + tools: list[BaseTool] = [TemplateFormatterTool(config=config)] + if config.email is not None: + tools.append(EmailSenderTool(config=config)) + if dispatcher is not None: + tools.append(NotificationSenderTool(dispatcher=dispatcher, config=config)) + return tuple(tools) + + +def _build_analytics_tools( + *, + config: AnalyticsToolsConfig | None = None, + provider: AnalyticsProvider | None = None, + metric_sink: MetricSink | None = None, +) -> tuple[BaseTool, ...]: + """Instantiate the built-in analytics tools. + + Returns an empty tuple when *config* is ``None``. 
+ """ + if config is None: + return () + from synthorg.tools.analytics import ( # noqa: PLC0415 + DataAggregatorTool, + MetricCollectorTool, + ReportGeneratorTool, + ) + + tools: list[BaseTool] = [] + if provider is not None: + tools.append(DataAggregatorTool(provider=provider, config=config)) + tools.append(ReportGeneratorTool(provider=provider, config=config)) + if metric_sink is not None: + tools.append(MetricCollectorTool(sink=metric_sink, config=config)) + return tuple(tools) + + def build_default_tools( # noqa: PLR0913 *, workspace: Path, @@ -163,6 +250,13 @@ def build_default_tools( # noqa: PLR0913 database_config: DatabaseConfig | None = None, terminal_config: TerminalConfig | None = None, terminal_sandbox: SandboxBackend | None = None, + design_config: DesignToolsConfig | None = None, + image_provider: ImageProvider | None = None, + communication_config: CommunicationToolsConfig | None = None, + communication_dispatcher: NotificationDispatcherProtocol | None = None, + analytics_config: AnalyticsToolsConfig | None = None, + analytics_provider: AnalyticsProvider | None = None, + metric_sink: MetricSink | None = None, ) -> tuple[BaseTool, ...]: """Instantiate all built-in workspace tools. @@ -179,6 +273,17 @@ def build_default_tools( # noqa: PLR0913 database tool creation. terminal_config: Terminal tool configuration. terminal_sandbox: Sandbox backend for terminal tools. + design_config: Design tool configuration. ``None`` skips + design tool creation. + image_provider: Image generation provider for design tools. + communication_config: Communication tool configuration. + ``None`` skips communication tool creation. + communication_dispatcher: Notification dispatcher for the + notification sender tool. + analytics_config: Analytics tool configuration. ``None`` + skips analytics tool creation. + analytics_provider: Analytics data provider. + metric_sink: Metric recording sink. Returns: Sorted tuple of ``BaseTool`` instances. 
@@ -221,6 +326,26 @@ def build_default_tools( # noqa: PLR0913 _build_database_tools(config=database_config), ) + all_tools.extend( + _build_design_tools( + config=design_config, + image_provider=image_provider, + ), + ) + all_tools.extend( + _build_communication_tools( + config=communication_config, + dispatcher=communication_dispatcher, + ), + ) + all_tools.extend( + _build_analytics_tools( + config=analytics_config, + provider=analytics_provider, + metric_sink=metric_sink, + ), + ) + result = tuple(sorted(all_tools, key=lambda t: t.name)) policy = git_clone_policy @@ -236,12 +361,16 @@ def build_default_tools( # noqa: PLR0913 return result -def build_default_tools_from_config( +def build_default_tools_from_config( # noqa: PLR0913 *, workspace: Path, config: RootConfig, sandbox_backends: Mapping[str, SandboxBackend] | None = None, web_search_provider: WebSearchProvider | None = None, + image_provider: ImageProvider | None = None, + communication_dispatcher: NotificationDispatcherProtocol | None = None, + analytics_provider: AnalyticsProvider | None = None, + metric_sink: MetricSink | None = None, ) -> tuple[BaseTool, ...]: """Build default tools using parameters from a ``RootConfig``. @@ -261,6 +390,12 @@ def build_default_tools_from_config( instead of auto-building backends. web_search_provider: Optional web search provider to inject into the web search tool. + image_provider: Optional image generation provider for design + tools. + communication_dispatcher: Optional notification dispatcher for + the notification sender tool. + analytics_provider: Optional analytics data provider. + metric_sink: Optional metric recording sink. Returns: Sorted tuple of ``BaseTool`` instances. 
@@ -323,4 +458,11 @@ def build_default_tools_from_config( database_config=config.database, terminal_config=config.terminal, terminal_sandbox=terminal_sandbox, + design_config=config.design_tools, + image_provider=image_provider, + communication_config=config.communication_tools, + communication_dispatcher=communication_dispatcher, + analytics_config=config.analytics_tools, + analytics_provider=analytics_provider, + metric_sink=metric_sink, ) diff --git a/tests/unit/config/conftest.py b/tests/unit/config/conftest.py index 5e3a98d5ac..0d1b29a3e4 100644 --- a/tests/unit/config/conftest.py +++ b/tests/unit/config/conftest.py @@ -106,6 +106,9 @@ class RootConfigFactory(ModelFactory[RootConfig]): coordination = CoordinationSectionConfig() backup = BackupConfig() workflow = WorkflowConfig() + design_tools = None + communication_tools = None + analytics_tools = None # ── Sample YAML strings ────────────────────────────────────────── diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index a5ee01208a..7d154c29ce 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -198,6 +198,7 @@ def test_all_domain_modules_discovered(self) -> None: "decomposition", "degradation", "delegation", + "design", "docker", "evaluation", "execution", diff --git a/tests/unit/tools/analytics/__init__.py b/tests/unit/tools/analytics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/tools/analytics/conftest.py b/tests/unit/tools/analytics/conftest.py new file mode 100644 index 0000000000..6c8ba7780c --- /dev/null +++ b/tests/unit/tools/analytics/conftest.py @@ -0,0 +1,112 @@ +"""Shared fixtures for analytics tool tests.""" + +from typing import Any + +import pytest + +from synthorg.tools.analytics.config import AnalyticsToolsConfig + + +class MockAnalyticsProvider: + """Mock analytics provider for testing.""" + + def __init__( + self, + *, + result: dict[str, Any] | None = 
None, + error: Exception | None = None, + ) -> None: + self._result = ( + { + "total_cost": 1234.56, + "task_count": 42, + "active_agents": 5, + } + if result is None + else result + ) + self._error = error + self.calls: list[dict[str, Any]] = [] + + async def query( + self, + *, + metrics: list[str], + period: str, + group_by: str | None = None, + start_date: str | None = None, + end_date: str | None = None, + ) -> dict[str, Any]: + self.calls.append( + { + "metrics": metrics, + "period": period, + "group_by": group_by, + "start_date": start_date, + "end_date": end_date, + } + ) + if self._error: + raise self._error + return self._result + + +class MockMetricSink: + """Mock metric sink for testing.""" + + def __init__( + self, + *, + error: Exception | None = None, + ) -> None: + self._error = error + self.recorded: list[dict[str, Any]] = [] + + async def record( + self, + *, + name: str, + value: float, + tags: dict[str, str] | None = None, + unit: str | None = None, + ) -> None: + self.recorded.append( + { + "name": name, + "value": value, + "tags": tags, + "unit": unit, + } + ) + if self._error: + raise self._error + + +@pytest.fixture +def default_config() -> AnalyticsToolsConfig: + return AnalyticsToolsConfig() + + +@pytest.fixture +def restricted_config() -> AnalyticsToolsConfig: + return AnalyticsToolsConfig(allowed_metrics=frozenset({"total_cost", "task_count"})) + + +@pytest.fixture +def mock_provider() -> MockAnalyticsProvider: + return MockAnalyticsProvider() + + +@pytest.fixture +def failing_provider() -> MockAnalyticsProvider: + return MockAnalyticsProvider(error=RuntimeError("query failed")) + + +@pytest.fixture +def mock_sink() -> MockMetricSink: + return MockMetricSink() + + +@pytest.fixture +def failing_sink() -> MockMetricSink: + return MockMetricSink(error=RuntimeError("sink error")) diff --git a/tests/unit/tools/analytics/test_config.py b/tests/unit/tools/analytics/test_config.py new file mode 100644 index 0000000000..5c635d4eda --- /dev/null 
+++ b/tests/unit/tools/analytics/test_config.py @@ -0,0 +1,61 @@ +"""Tests for analytics tool configuration models.""" + +from typing import Any + +import pytest +from pydantic import ValidationError + +from synthorg.tools.analytics.config import AnalyticsToolsConfig + + +@pytest.mark.unit +class TestAnalyticsToolsConfig: + """Tests for AnalyticsToolsConfig.""" + + def test_default_values(self) -> None: + config = AnalyticsToolsConfig() + assert config.query_timeout == 60.0 + assert config.max_rows == 10_000 + assert config.allowed_metrics is None + + def test_frozen(self) -> None: + config = AnalyticsToolsConfig() + with pytest.raises(ValidationError): + config.query_timeout = 30.0 # type: ignore[misc] + + def test_custom_values(self) -> None: + config = AnalyticsToolsConfig( + query_timeout=120.0, + max_rows=5000, + allowed_metrics=frozenset({"total_cost", "task_count"}), + ) + assert config.query_timeout == 120.0 + assert config.max_rows == 5000 + assert config.allowed_metrics == frozenset({"total_cost", "task_count"}) + + @pytest.mark.parametrize( + "kwargs", + [ + {"query_timeout": 0}, + {"query_timeout": 301.0}, + {"max_rows": 0}, + {"max_rows": 100_001}, + {"query_timeout": float("nan")}, + {"query_timeout": float("inf")}, + ], + ids=[ + "timeout_zero", + "timeout_above_max", + "rows_zero", + "rows_above_max", + "timeout_nan", + "timeout_inf", + ], + ) + def test_rejects_invalid_params(self, kwargs: dict[str, Any]) -> None: + with pytest.raises(ValidationError): + AnalyticsToolsConfig(**kwargs) + + def test_blank_metric_name_rejected(self) -> None: + with pytest.raises(ValidationError): + AnalyticsToolsConfig(allowed_metrics=frozenset({"valid", " "})) diff --git a/tests/unit/tools/analytics/test_data_aggregator.py b/tests/unit/tools/analytics/test_data_aggregator.py new file mode 100644 index 0000000000..63eb3a66c1 --- /dev/null +++ b/tests/unit/tools/analytics/test_data_aggregator.py @@ -0,0 +1,207 @@ +"""Tests for the data aggregator tool.""" + +import 
pytest + +from synthorg.core.enums import ActionType, ToolCategory +from synthorg.tools.analytics.config import AnalyticsToolsConfig +from synthorg.tools.analytics.data_aggregator import ( + AnalyticsProvider, + DataAggregatorTool, +) + +from .conftest import MockAnalyticsProvider + + +@pytest.mark.unit +class TestDataAggregatorTool: + """Tests for DataAggregatorTool.""" + + @pytest.mark.parametrize( + ("attr", "expected"), + [ + ("category", ToolCategory.ANALYTICS), + ("action_type", ActionType.CODE_READ), + ("name", "data_aggregator"), + ], + ids=["category", "action_type", "name"], + ) + def test_tool_attributes( + self, + mock_provider: MockAnalyticsProvider, + attr: str, + expected: object, + ) -> None: + tool = DataAggregatorTool(provider=mock_provider) + assert getattr(tool, attr) == expected + + async def test_execute_no_provider_returns_error(self) -> None: + tool = DataAggregatorTool(provider=None) + result = await tool.execute( + arguments={ + "metrics": ["total_cost"], + "period": "7d", + } + ) + assert result.is_error + assert "No AnalyticsProvider" in result.content + + async def test_execute_success( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = DataAggregatorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "metrics": ["total_cost", "task_count"], + "period": "7d", + } + ) + assert not result.is_error + assert "total_cost" in result.content + assert result.metadata["total_cost"] == 1234.56 + assert len(mock_provider.calls) == 1 + + async def test_execute_passes_all_params( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = DataAggregatorTool(provider=mock_provider) + await tool.execute( + arguments={ + "metrics": ["total_cost"], + "period": "custom", + "group_by": "day", + "start_date": "2026-01-01", + "end_date": "2026-01-31", + } + ) + call = mock_provider.calls[0] + assert call["metrics"] == ["total_cost"] + assert call["period"] == "custom" + assert call["group_by"] == "day" + 
assert call["start_date"] == "2026-01-01" + assert call["end_date"] == "2026-01-31" + + async def test_execute_invalid_period( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = DataAggregatorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "metrics": ["total_cost"], + "period": "invalid", + } + ) + assert result.is_error + assert "Invalid period" in result.content + + async def test_execute_invalid_group_by( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = DataAggregatorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "metrics": ["total_cost"], + "period": "7d", + "group_by": "invalid", + } + ) + assert result.is_error + assert "Invalid group_by" in result.content + + async def test_execute_provider_error( + self, + failing_provider: MockAnalyticsProvider, + ) -> None: + tool = DataAggregatorTool(provider=failing_provider) + result = await tool.execute( + arguments={ + "metrics": ["total_cost"], + "period": "7d", + } + ) + assert result.is_error + assert "Analytics query failed" in result.content + + async def test_execute_metric_whitelist( + self, + mock_provider: MockAnalyticsProvider, + restricted_config: AnalyticsToolsConfig, + ) -> None: + tool = DataAggregatorTool( + provider=mock_provider, + config=restricted_config, + ) + result = await tool.execute( + arguments={ + "metrics": ["total_cost", "secret_metric"], + "period": "7d", + } + ) + assert result.is_error + assert "not allowed" in result.content + + async def test_execute_metric_whitelist_allowed( + self, + mock_provider: MockAnalyticsProvider, + restricted_config: AnalyticsToolsConfig, + ) -> None: + tool = DataAggregatorTool( + provider=mock_provider, + config=restricted_config, + ) + result = await tool.execute( + arguments={ + "metrics": ["total_cost"], + "period": "7d", + } + ) + assert not result.is_error + + async def test_execute_custom_period_requires_dates( + self, + mock_provider: 
MockAnalyticsProvider, + ) -> None: + tool = DataAggregatorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "metrics": ["total_cost"], + "period": "custom", + } + ) + assert result.is_error + assert "start_date" in result.content + + async def test_execute_invalid_date_format( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = DataAggregatorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "metrics": ["total_cost"], + "period": "custom", + "start_date": "not-a-date", + "end_date": "2026-01-31", + } + ) + assert result.is_error + assert "Invalid start_date" in result.content + + def test_mock_provider_satisfies_protocol( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + assert isinstance(mock_provider, AnalyticsProvider) + + def test_parameters_schema_requires_metrics_and_period( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = DataAggregatorTool(provider=mock_provider) + schema = tool.parameters_schema + assert schema is not None + assert "metrics" in schema["required"] + assert "period" in schema["required"] diff --git a/tests/unit/tools/analytics/test_metric_collector.py b/tests/unit/tools/analytics/test_metric_collector.py new file mode 100644 index 0000000000..1eb1faed0f --- /dev/null +++ b/tests/unit/tools/analytics/test_metric_collector.py @@ -0,0 +1,165 @@ +"""Tests for the metric collector tool.""" + +import pytest + +from synthorg.core.enums import ToolCategory +from synthorg.tools.analytics.config import AnalyticsToolsConfig +from synthorg.tools.analytics.metric_collector import ( + MetricCollectorTool, + MetricSink, +) + +from .conftest import MockMetricSink + + +@pytest.mark.unit +class TestMetricCollectorTool: + """Tests for MetricCollectorTool.""" + + @pytest.mark.parametrize( + ("attr", "expected"), + [ + ("category", ToolCategory.ANALYTICS), + ("action_type", "metrics:record"), + ("name", "metric_collector"), + ], + ids=["category", 
"action_type", "name"], + ) + def test_tool_attributes( + self, mock_sink: MockMetricSink, attr: str, expected: object + ) -> None: + tool = MetricCollectorTool(sink=mock_sink) + assert getattr(tool, attr) == expected + + async def test_execute_no_sink_returns_error(self) -> None: + tool = MetricCollectorTool(sink=None) + result = await tool.execute( + arguments={ + "metric_name": "test", + "value": 1.0, + } + ) + assert result.is_error + assert "No MetricSink" in result.content + + async def test_execute_success( + self, + mock_sink: MockMetricSink, + ) -> None: + tool = MetricCollectorTool(sink=mock_sink) + result = await tool.execute( + arguments={ + "metric_name": "response_time", + "value": 1.23, + "unit": "seconds", + } + ) + assert not result.is_error + assert "response_time" in result.content + assert "1.23" in result.content + assert "seconds" in result.content + assert len(mock_sink.recorded) == 1 + recorded = mock_sink.recorded[0] + assert recorded["name"] == "response_time" + assert recorded["value"] == 1.23 + + async def test_execute_with_tags( + self, + mock_sink: MockMetricSink, + ) -> None: + tool = MetricCollectorTool(sink=mock_sink) + result = await tool.execute( + arguments={ + "metric_name": "request_count", + "value": 42, + "tags": {"endpoint": "/api/tasks"}, + } + ) + assert not result.is_error + assert len(mock_sink.recorded) == 1 + recorded = mock_sink.recorded[0] + assert recorded["tags"]["endpoint"] == "/api/tasks" + + async def test_execute_metric_not_allowed( + self, + mock_sink: MockMetricSink, + restricted_config: AnalyticsToolsConfig, + ) -> None: + tool = MetricCollectorTool( + sink=mock_sink, + config=restricted_config, + ) + result = await tool.execute( + arguments={ + "metric_name": "secret_metric", + "value": 1.0, + } + ) + assert result.is_error + assert "not allowed" in result.content + + async def test_execute_metric_allowed( + self, + mock_sink: MockMetricSink, + restricted_config: AnalyticsToolsConfig, + ) -> None: + tool = 
MetricCollectorTool( + sink=mock_sink, + config=restricted_config, + ) + result = await tool.execute( + arguments={ + "metric_name": "total_cost", + "value": 100.0, + } + ) + assert not result.is_error + + async def test_execute_sink_error( + self, + failing_sink: MockMetricSink, + ) -> None: + tool = MetricCollectorTool(sink=failing_sink) + result = await tool.execute( + arguments={ + "metric_name": "test", + "value": 1.0, + } + ) + assert result.is_error + assert "Metric recording failed" in result.content + + async def test_execute_returns_metadata( + self, + mock_sink: MockMetricSink, + ) -> None: + tool = MetricCollectorTool(sink=mock_sink) + result = await tool.execute( + arguments={ + "metric_name": "cpu_usage", + "value": 85.5, + "unit": "percent", + "tags": {"host": "worker-1"}, + } + ) + assert not result.is_error + assert result.metadata["metric_name"] == "cpu_usage" + assert result.metadata["value"] == 85.5 + assert result.metadata["unit"] == "percent" + assert result.metadata["tags"]["host"] == "worker-1" + + def test_mock_sink_satisfies_protocol( + self, + mock_sink: MockMetricSink, + ) -> None: + assert isinstance(mock_sink, MetricSink) + + def test_parameters_schema_requires_name_and_value( + self, + mock_sink: MockMetricSink, + ) -> None: + tool = MetricCollectorTool(sink=mock_sink) + schema = tool.parameters_schema + assert schema is not None + assert "metric_name" in schema["required"] + assert "value" in schema["required"] diff --git a/tests/unit/tools/analytics/test_report_generator.py b/tests/unit/tools/analytics/test_report_generator.py new file mode 100644 index 0000000000..4b4bb17abc --- /dev/null +++ b/tests/unit/tools/analytics/test_report_generator.py @@ -0,0 +1,159 @@ +"""Tests for the report generator tool.""" + +import json + +import pytest + +from synthorg.core.enums import ActionType, ToolCategory +from synthorg.tools.analytics.report_generator import ReportGeneratorTool + +from .conftest import MockAnalyticsProvider + + 
+@pytest.mark.unit +class TestReportGeneratorTool: + """Tests for ReportGeneratorTool.""" + + @pytest.mark.parametrize( + ("attr", "expected"), + [ + ("category", ToolCategory.ANALYTICS), + ("action_type", ActionType.CODE_READ), + ("name", "report_generator"), + ], + ids=["category", "action_type", "name"], + ) + def test_tool_attributes( + self, mock_provider: MockAnalyticsProvider, attr: str, expected: object + ) -> None: + tool = ReportGeneratorTool(provider=mock_provider) + assert getattr(tool, attr) == expected + + async def test_execute_no_provider_returns_error(self) -> None: + tool = ReportGeneratorTool(provider=None) + result = await tool.execute( + arguments={ + "report_type": "budget_summary", + "period": "30d", + } + ) + assert result.is_error + assert "No AnalyticsProvider" in result.content + + async def test_execute_markdown_report( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = ReportGeneratorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "report_type": "budget_summary", + "period": "30d", + "format": "markdown", + } + ) + assert not result.is_error + assert "# Budget Summary Report" in result.content + assert "**Period:** 30d" in result.content + assert result.metadata["format"] == "markdown" + + async def test_execute_text_report( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = ReportGeneratorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "report_type": "performance", + "period": "7d", + "format": "text", + } + ) + assert not result.is_error + assert "Performance Report" in result.content + assert "Period: 7d" in result.content + + async def test_execute_json_report( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = ReportGeneratorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "report_type": "cost_breakdown", + "period": "90d", + "format": "json", + } + ) + assert not result.is_error + parsed = 
json.loads(result.content) + assert parsed["report_type"] == "cost_breakdown" + assert parsed["period"] == "90d" + assert "data" in parsed + + async def test_execute_default_format_is_markdown( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = ReportGeneratorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "report_type": "trend_analysis", + "period": "7d", + } + ) + assert not result.is_error + assert result.metadata["format"] == "markdown" + + async def test_execute_invalid_report_type( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = ReportGeneratorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "report_type": "invalid", + "period": "7d", + } + ) + assert result.is_error + assert "Invalid report_type" in result.content + + async def test_execute_invalid_format( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = ReportGeneratorTool(provider=mock_provider) + result = await tool.execute( + arguments={ + "report_type": "budget_summary", + "period": "7d", + "format": "csv", + } + ) + assert result.is_error + assert "Invalid format" in result.content + + async def test_execute_provider_error( + self, + failing_provider: MockAnalyticsProvider, + ) -> None: + tool = ReportGeneratorTool(provider=failing_provider) + result = await tool.execute( + arguments={ + "report_type": "budget_summary", + "period": "7d", + } + ) + assert result.is_error + assert "query failed" in result.content + + def test_parameters_schema_required_fields( + self, + mock_provider: MockAnalyticsProvider, + ) -> None: + tool = ReportGeneratorTool(provider=mock_provider) + schema = tool.parameters_schema + assert schema is not None + assert "report_type" in schema["required"] + assert "period" in schema["required"] diff --git a/tests/unit/tools/communication/__init__.py b/tests/unit/tools/communication/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/tests/unit/tools/communication/conftest.py b/tests/unit/tools/communication/conftest.py new file mode 100644 index 0000000000..1ba20e8615 --- /dev/null +++ b/tests/unit/tools/communication/conftest.py @@ -0,0 +1,55 @@ +"""Shared fixtures for communication tool tests.""" + +import pytest + +from synthorg.notifications.models import Notification +from synthorg.tools.communication.config import ( + CommunicationToolsConfig, + EmailConfig, +) + + +class MockNotificationDispatcher: + """Mock notification dispatcher for testing.""" + + def __init__( + self, + *, + error: Exception | None = None, + ) -> None: + self._error = error + self.dispatched: list[Notification] = [] + + async def dispatch(self, notification: Notification) -> None: + if self._error: + raise self._error + self.dispatched.append(notification) + + +@pytest.fixture +def email_config() -> EmailConfig: + return EmailConfig( + host="smtp.example.com", + port=587, + from_address="test@example.com", + ) + + +@pytest.fixture +def comm_config(email_config: EmailConfig) -> CommunicationToolsConfig: + return CommunicationToolsConfig(email=email_config) + + +@pytest.fixture +def comm_config_no_email() -> CommunicationToolsConfig: + return CommunicationToolsConfig() + + +@pytest.fixture +def mock_dispatcher() -> MockNotificationDispatcher: + return MockNotificationDispatcher() + + +@pytest.fixture +def failing_dispatcher() -> MockNotificationDispatcher: + return MockNotificationDispatcher(error=RuntimeError("dispatch failed")) diff --git a/tests/unit/tools/communication/test_config.py b/tests/unit/tools/communication/test_config.py new file mode 100644 index 0000000000..e24788936d --- /dev/null +++ b/tests/unit/tools/communication/test_config.py @@ -0,0 +1,154 @@ +"""Tests for communication tool configuration models.""" + +from typing import Any + +import pytest +from pydantic import ValidationError + +from synthorg.tools.communication.config import ( + CommunicationToolsConfig, + EmailConfig, +) + + 
+@pytest.mark.unit +class TestEmailConfig: + """Tests for EmailConfig.""" + + def test_required_fields(self) -> None: + config = EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + ) + assert config.host == "smtp.example.com" + assert config.port == 587 + assert config.from_address == "test@example.com" + assert config.use_tls is True + assert config.username is None + assert config.password is None + + def test_frozen(self) -> None: + config = EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + ) + with pytest.raises(ValidationError): + config.host = "other" # type: ignore[misc] + + @pytest.mark.parametrize("port", [0, 70000], ids=["too_low", "too_high"]) + def test_invalid_port(self, port: int) -> None: + with pytest.raises(ValidationError): + EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + port=port, + ) + + def test_blank_host_rejected(self) -> None: + with pytest.raises(ValidationError): + EmailConfig(host=" ", from_address="test@example.com") + + def test_password_not_in_repr(self) -> None: + config = EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + username="user", + password="secret", + ) + assert "secret" not in repr(config) + + @pytest.mark.parametrize( + "kwargs", + [ + {"username": "user"}, + {"password": "secret"}, + ], + ids=["username_only", "password_only"], + ) + def test_partial_credentials_rejected(self, kwargs: dict[str, Any]) -> None: + with pytest.raises(ValidationError, match="username and password"): + EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + **kwargs, + ) + + def test_both_credentials_accepted(self) -> None: + config = EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + username="user", + password="secret", + ) + assert config.username == "user" + + def test_no_credentials_accepted(self) -> None: + config = EmailConfig( + host="smtp.example.com", + 
from_address="test@example.com", + ) + assert config.username is None + assert config.password is None + + def test_tls_mutual_exclusivity(self) -> None: + with pytest.raises(ValidationError, match="mutually exclusive"): + EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + use_tls=True, + use_implicit_tls=True, + ) + + def test_smtp_timeout_valid(self) -> None: + config = EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + smtp_timeout=30.0, + ) + assert config.smtp_timeout == 30.0 + + @pytest.mark.parametrize( + "timeout", + [0, -1.0, 121.0, float("nan"), float("inf")], + ids=["zero", "negative", "above_max", "nan", "inf"], + ) + def test_smtp_timeout_invalid(self, timeout: float) -> None: + with pytest.raises(ValidationError): + EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + smtp_timeout=timeout, + ) + + +@pytest.mark.unit +class TestCommunicationToolsConfig: + """Tests for CommunicationToolsConfig.""" + + def test_default_values(self) -> None: + config = CommunicationToolsConfig() + assert config.email is None + assert config.max_recipients == 100 + + def test_frozen(self) -> None: + config = CommunicationToolsConfig() + with pytest.raises(ValidationError): + config.max_recipients = 50 # type: ignore[misc] + + def test_with_email(self) -> None: + email = EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + ) + config = CommunicationToolsConfig(email=email) + assert config.email is not None + assert config.email.host == "smtp.example.com" + + @pytest.mark.parametrize( + "value", + [0, 1001, float("nan")], + ids=["zero", "above_max", "nan"], + ) + def test_invalid_max_recipients(self, value: Any) -> None: + with pytest.raises(ValidationError): + CommunicationToolsConfig(max_recipients=value) diff --git a/tests/unit/tools/communication/test_email_sender.py b/tests/unit/tools/communication/test_email_sender.py new file mode 100644 index 0000000000..d245906141 --- 
/dev/null +++ b/tests/unit/tools/communication/test_email_sender.py @@ -0,0 +1,159 @@ +"""Tests for the email sender tool.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from synthorg.core.enums import ActionType, ToolCategory +from synthorg.tools.communication.config import CommunicationToolsConfig, EmailConfig +from synthorg.tools.communication.email_sender import EmailSenderTool + + +@pytest.mark.unit +class TestEmailSenderTool: + """Tests for EmailSenderTool.""" + + @pytest.mark.parametrize( + ("attr", "expected"), + [ + ("category", ToolCategory.COMMUNICATION), + ("action_type", ActionType.COMMS_EXTERNAL), + ("name", "email_sender"), + ], + ids=["category", "action_type", "name"], + ) + def test_tool_attributes( + self, + comm_config: CommunicationToolsConfig, + attr: str, + expected: object, + ) -> None: + tool = EmailSenderTool(config=comm_config) + assert getattr(tool, attr) == expected + + async def test_execute_no_email_config_returns_error( + self, + comm_config_no_email: CommunicationToolsConfig, + ) -> None: + tool = EmailSenderTool(config=comm_config_no_email) + result = await tool.execute( + arguments={ + "to": ["user@example.com"], + "subject": "Test", + } + ) + assert result.is_error + assert "SMTP configuration" in result.content + + async def test_execute_empty_recipients_returns_error( + self, + comm_config: CommunicationToolsConfig, + ) -> None: + tool = EmailSenderTool(config=comm_config) + result = await tool.execute(arguments={"to": [], "subject": "Test"}) + assert result.is_error + assert "At least one recipient" in result.content + + async def test_execute_too_many_recipients(self) -> None: + config = CommunicationToolsConfig( + email=EmailConfig( + host="smtp.example.com", + from_address="test@example.com", + ), + max_recipients=2, + ) + tool = EmailSenderTool(config=config) + result = await tool.execute( + arguments={ + "to": ["a@ex.com", "b@ex.com", "c@ex.com"], + "subject": "Test", + } + ) + assert result.is_error 
+ assert "Too many recipients" in result.content + + @patch.object(EmailSenderTool, "_send_sync") + async def test_execute_success( + self, + mock_send: MagicMock, + comm_config: CommunicationToolsConfig, + ) -> None: + tool = EmailSenderTool(config=comm_config) + result = await tool.execute( + arguments={ + "to": ["user@example.com"], + "subject": "Hello", + "body": "World", + } + ) + assert not result.is_error + assert "sent successfully" in result.content + assert result.metadata["to"] == ["user@example.com"] + mock_send.assert_called_once() + + @patch.object( + EmailSenderTool, + "_send_sync", + side_effect=RuntimeError("SMTP error"), + ) + async def test_execute_smtp_error( + self, + mock_send: MagicMock, + comm_config: CommunicationToolsConfig, + ) -> None: + tool = EmailSenderTool(config=comm_config) + result = await tool.execute( + arguments={ + "to": ["user@example.com"], + "subject": "Test", + } + ) + assert result.is_error + assert "Email sending failed" in result.content + mock_send.assert_called_once() + + @patch.object(EmailSenderTool, "_send_sync") + async def test_execute_with_cc_and_bcc( + self, + mock_send: MagicMock, + comm_config: CommunicationToolsConfig, + ) -> None: + tool = EmailSenderTool(config=comm_config) + result = await tool.execute( + arguments={ + "to": ["a@ex.com"], + "cc": ["b@ex.com"], + "bcc": ["c@ex.com"], + "subject": "Test", + } + ) + assert not result.is_error + assert "3 recipient(s)" in result.content + mock_send.assert_called_once() + + @patch.object(EmailSenderTool, "_send_sync") + async def test_execute_rejects_newline_in_address( + self, + mock_send: MagicMock, + comm_config: CommunicationToolsConfig, + ) -> None: + tool = EmailSenderTool(config=comm_config) + result = await tool.execute( + arguments={ + "to": ["attacker@ex.com\nBcc: victim@ex.com"], + "subject": "Test", + } + ) + assert result.is_error + assert "invalid characters" in result.content + mock_send.assert_not_called() + + def 
test_parameters_schema_requires_to_and_subject( + self, + comm_config: CommunicationToolsConfig, + ) -> None: + tool = EmailSenderTool(config=comm_config) + schema = tool.parameters_schema + assert schema is not None + assert "to" in schema["required"] + assert "subject" in schema["required"] diff --git a/tests/unit/tools/communication/test_notification_sender.py b/tests/unit/tools/communication/test_notification_sender.py new file mode 100644 index 0000000000..659bee0e63 --- /dev/null +++ b/tests/unit/tools/communication/test_notification_sender.py @@ -0,0 +1,146 @@ +"""Tests for the notification sender tool.""" + +import pytest + +from synthorg.core.enums import ActionType, ToolCategory +from synthorg.tools.communication.notification_sender import ( + NotificationSenderTool, +) + +from .conftest import MockNotificationDispatcher + + +@pytest.mark.unit +class TestNotificationSenderTool: + """Tests for NotificationSenderTool.""" + + @pytest.mark.parametrize( + ("attr", "expected"), + [ + ("category", ToolCategory.COMMUNICATION), + ("action_type", ActionType.COMMS_INTERNAL), + ("name", "notification_sender"), + ], + ids=["category", "action_type", "name"], + ) + def test_tool_attributes( + self, + mock_dispatcher: MockNotificationDispatcher, + attr: str, + expected: object, + ) -> None: + tool = NotificationSenderTool(dispatcher=mock_dispatcher) + assert getattr(tool, attr) == expected + + async def test_execute_no_dispatcher_returns_error(self) -> None: + tool = NotificationSenderTool(dispatcher=None) + result = await tool.execute( + arguments={ + "category": "system", + "severity": "info", + "title": "Test", + "source": "test-agent", + } + ) + assert result.is_error + assert "NotificationDispatcher" in result.content + + async def test_execute_success( + self, + mock_dispatcher: MockNotificationDispatcher, + ) -> None: + tool = NotificationSenderTool(dispatcher=mock_dispatcher) + result = await tool.execute( + arguments={ + "category": "system", + "severity": 
"info", + "title": "Deployment complete", + "source": "deploy-agent", + "body": "All services healthy.", + } + ) + assert not result.is_error + assert "Deployment complete" in result.content + assert len(mock_dispatcher.dispatched) == 1 + notif = mock_dispatcher.dispatched[0] + assert notif.title == "Deployment complete" + assert notif.source == "deploy-agent" + + @pytest.mark.parametrize( + ("args", "expected_msg"), + [ + ( + { + "category": "invalid", + "severity": "info", + "title": "Test", + "source": "test", + }, + "Invalid category", + ), + ( + { + "category": "system", + "severity": "invalid", + "title": "Test", + "source": "test", + }, + "Invalid severity", + ), + ], + ids=["invalid_category", "invalid_severity"], + ) + async def test_execute_invalid_enum( + self, + mock_dispatcher: MockNotificationDispatcher, + args: dict[str, str], + expected_msg: str, + ) -> None: + tool = NotificationSenderTool(dispatcher=mock_dispatcher) + result = await tool.execute(arguments=args) + assert result.is_error + assert expected_msg in result.content + + async def test_execute_dispatch_error( + self, + failing_dispatcher: MockNotificationDispatcher, + ) -> None: + tool = NotificationSenderTool(dispatcher=failing_dispatcher) + result = await tool.execute( + arguments={ + "category": "system", + "severity": "error", + "title": "Alert", + "source": "test", + } + ) + assert result.is_error + assert "dispatch failed" in result.content + + async def test_execute_returns_metadata( + self, + mock_dispatcher: MockNotificationDispatcher, + ) -> None: + tool = NotificationSenderTool(dispatcher=mock_dispatcher) + result = await tool.execute( + arguments={ + "category": "budget", + "severity": "warning", + "title": "Budget threshold", + "source": "budget-enforcer", + } + ) + assert not result.is_error + assert result.metadata["category"] == "budget" + assert result.metadata["severity"] == "warning" + assert "notification_id" in result.metadata + + def 
test_parameters_schema_required_fields( + self, + mock_dispatcher: MockNotificationDispatcher, + ) -> None: + tool = NotificationSenderTool(dispatcher=mock_dispatcher) + schema = tool.parameters_schema + assert schema is not None + for field in ("category", "severity", "title", "source"): + assert field in schema["required"] diff --git a/tests/unit/tools/communication/test_template_formatter.py b/tests/unit/tools/communication/test_template_formatter.py new file mode 100644 index 0000000000..dea36f4b5a --- /dev/null +++ b/tests/unit/tools/communication/test_template_formatter.py @@ -0,0 +1,210 @@ +"""Tests for the template formatter tool.""" + +from typing import Any + +import pytest + +from synthorg.core.enums import ActionType, ToolCategory +from synthorg.tools.communication.template_formatter import ( + TemplateFormatterTool, +) + + +@pytest.mark.unit +class TestTemplateFormatterTool: + """Tests for TemplateFormatterTool.""" + + @pytest.mark.parametrize( + ("attr", "expected"), + [ + ("category", ToolCategory.COMMUNICATION), + ("action_type", ActionType.CODE_READ), + ("name", "template_formatter"), + ], + ids=["category", "action_type", "name"], + ) + def test_tool_attributes(self, attr: str, expected: object) -> None: + tool = TemplateFormatterTool() + assert getattr(tool, attr) == expected + + async def test_execute_simple_template(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "Hello {{ name }}!", + "variables": {"name": "Alice"}, + } + ) + assert not result.is_error + assert result.content == "Hello Alice!" 
+ + async def test_execute_multiple_variables(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "{{ greeting }} {{ name }}, balance: {{ amount }}", + "variables": { + "greeting": "Hi", + "name": "Bob", + "amount": "$100", + }, + } + ) + assert not result.is_error + assert result.content == "Hi Bob, balance: $100" + + async def test_execute_invalid_template_syntax(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "Hello {{ name", + "variables": {"name": "test"}, + } + ) + assert result.is_error + assert "Invalid template syntax" in result.content + + async def test_execute_undefined_variable(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "Hello {{ name }}!", + "variables": {}, + } + ) + # Jinja2 renders undefined as empty string by default + assert not result.is_error + assert result.content == "Hello !" + + async def test_execute_with_format_metadata(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "# {{ title }}", + "variables": {"title": "Report"}, + "format": "markdown", + } + ) + assert not result.is_error + assert result.metadata["format"] == "markdown" + assert result.metadata["output_length"] == len("# Report") + + async def test_execute_invalid_format(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "test", + "variables": {}, + "format": "yaml", + } + ) + assert result.is_error + assert "Invalid format" in result.content + + @pytest.mark.parametrize( + ("args", "expected_msg"), + [ + ( + {"template": 123, "variables": {}}, + "'template' must be a string", + ), + ( + {"template": "hi", "variables": "notadict"}, + "'variables' must be a dict", + ), + ( + {"template": "hi", "variables": {}, "format": 123}, + "'format' must be a string", + ), + ], + ids=["template_not_str", 
"variables_not_dict", "format_not_str"], + ) + async def test_execute_rejects_invalid_arg_types( + self, args: dict[str, Any], expected_msg: str + ) -> None: + tool = TemplateFormatterTool() + result = await tool.execute(arguments=args) + assert result.is_error + assert expected_msg in result.content + + async def test_execute_html_template(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "

{{ title }}

{{ body }}

", + "variables": {"title": "Hello", "body": "World"}, + "format": "html", + } + ) + assert not result.is_error + assert "

Hello

" in result.content + + async def test_jinja2_conditionals(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "{% if urgent %}URGENT: {% endif %}{{ msg }}", + "variables": {"urgent": True, "msg": "Server down"}, + } + ) + assert not result.is_error + assert result.content == "URGENT: Server down" + + async def test_jinja2_loop(self) -> None: + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "{% for item in items %}{{ item }} {% endfor %}", + "variables": {"items": ["a", "b", "c"]}, + } + ) + assert not result.is_error + assert "a b c" in result.content + + async def test_sandbox_blocks_attribute_access(self) -> None: + """SandboxedEnvironment prevents dangerous attribute access.""" + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "{{ ''.__class__.__bases__ }}", + "variables": {}, + } + ) + # Sandbox must block dangerous attribute access + assert result.is_error + + async def test_html_format_escapes_xss(self) -> None: + """HTML format auto-escapes to prevent XSS.""" + tool = TemplateFormatterTool() + result = await tool.execute( + arguments={ + "template": "

{{ content }}

", + "variables": {"content": ""}, + "format": "html", + } + ) + assert not result.is_error + assert "