async def handle_rank_alternatives(call: object) -> None:
    """Service handler: force-run the ranking pipeline on user request.

    Registered as the ``rank_alternatives`` service. This is a closure:
    ``coordinator`` and ``_LOGGER`` come from the enclosing scope, so the
    function must stay nested where those names are bound.

    Args:
        call: The service call object; only ``call.data`` is read. The
            optional ``top_k`` entry selects how many alternatives to keep.

    The ``top_k`` value is coerced defensively: anything ``int()`` cannot
    convert falls back to 20, and the result is clamped to ``1..100`` to
    match the selector bounds declared in ``services.yaml``.
    """
    raw = call.data.get("top_k", 20)  # type: ignore[attr-defined]
    try:
        top_k = int(raw)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / list / dict payloads.
        # ValueError: non-numeric strings and float NaN.
        # OverflowError: float infinity (e.g. YAML ``.inf``) — previously
        # uncaught, so a single bad automation value failed the call.
        _LOGGER.warning(
            "rank_alternatives: invalid top_k=%r, using default 20", raw
        )
        top_k = 20
    top_k = max(1, min(top_k, 100))
    result = await coordinator.async_run_ranking_job(top_k=top_k)
    _LOGGER.info(
        "rank_alternatives service: ran successfully, %d result(s)",
        len(result),
    )
# Brand-name fragments looked up in the registry on every ranking run,
# in addition to the user's own retailer.
DEFAULT_COMPETITOR_BRAND_FRAGMENTS: tuple[str, ...] = (
    "agl",
    "origin",
    "energyaustralia",
    "red energy",
)


def _clean_text(value: object) -> str | None:
    """Return *value* stripped if it is a non-blank ``str``, else ``None``."""
    if not isinstance(value, str):
        return None
    return value.strip() or None


def get_user_geography(
    options: dict[str, Any],
) -> tuple[str | None, str | None, str | None]:
    """Pull ``(state, postcode, distributor)`` from config-entry options.

    Args:
        options: The config entry's options mapping.

    Returns:
        A 3-tuple:

        - ``state``: always ``None`` here.
        - ``postcode``: the ``cdr_postcode`` option as a stripped string.
          An ``int`` is converted with ``str()``; blank strings and any
          other type yield ``None``.
        - ``distributor``: the first entry of
          ``cdr_plan.data.geography.distributors`` when that is a non-empty
          list whose first element is a non-blank string, else ``None``.

    Never raises on malformed ``cdr_plan`` shapes; each level is
    type-checked before it is dereferenced.
    """
    raw_postcode = options.get("cdr_postcode")
    # The postcode was the one field returned unchecked. ``bool`` is
    # excluded explicitly because it is a subclass of ``int``.
    # NOTE(review): an int-stored postcode has already lost any leading
    # zero; confirm the wizard stores it as a string.
    if isinstance(raw_postcode, int) and not isinstance(raw_postcode, bool):
        raw_postcode = str(raw_postcode)
    postcode = _clean_text(raw_postcode)

    geo = _safe_plan_data(options).get("geography")
    if not isinstance(geo, dict):
        return None, postcode, None

    distributors = geo.get("distributors")
    # A bare string must not be indexed: ``"United Energy"[0]`` is ``"U"``.
    first = (
        distributors[0]
        if isinstance(distributors, list) and distributors
        else None
    )
    return None, postcode, _clean_text(first)


def _safe_plan_data(options: dict[str, Any]) -> dict[str, Any]:
    """Pull ``cdr_plan.data`` safely; return ``{}`` on any malformed shape.

    Tolerates ``cdr_plan`` missing or not a dict, and ``data`` missing or
    not a dict.
    """
    cdr_plan = options.get("cdr_plan")
    if not isinstance(cdr_plan, dict):
        return {}
    plan_data = cdr_plan.get("data")
    return plan_data if isinstance(plan_data, dict) else {}
async def get_competitor_retailers(
    session: aiohttp.ClientSession,
    options: dict[str, Any],
    *,
    competitor_fragments: tuple[str, ...] = DEFAULT_COMPETITOR_BRAND_FRAGMENTS,
) -> list[RetailerEndpoint]:
    """Resolve the retailers to scan during a ranking run.

    The result is ordered and de-duplicated by ``brand_id``:

    1. The user's current retailer, looked up from ``cdr_plan.data.brand``
       (only when that value is a non-blank string).
    2. One registry match per entry of ``competitor_fragments``.

    Fragments with no registry match are skipped. An empty registry
    therefore produces an empty list.
    """
    registry, origin = await get_registry(session)
    _LOGGER.debug(
        "ranking: registry source=%s, %d retailers", origin, len(registry)
    )

    # ``brand`` is only trusted when it is a string; anything else is
    # treated as "no current retailer".
    brand_value = _safe_plan_data(options).get("brand")
    own_brand = brand_value.strip() if isinstance(brand_value, str) else ""

    # One pass over all search terms, current retailer first so it wins
    # the de-duplication and lands at index 0.
    search_terms = ([own_brand] if own_brand else []) + list(competitor_fragments)

    selected: list[RetailerEndpoint] = []
    taken_ids: set[str] = set()
    for term in search_terms:
        endpoint = find_by_brand(registry, term)
        if endpoint is None or endpoint.brand_id in taken_ids:
            continue
        selected.append(endpoint)
        taken_ids.add(endpoint.brand_id)
    return selected


async def run_ranking_job(
    session: aiohttp.ClientSession,
    options: dict[str, Any],
    *,
    top_k: int = DEFAULT_TOP_K,
    plan_cache: dict[str, dict[str, Any]] | None = None,
    competitor_fragments: tuple[str, ...] = DEFAULT_COMPETITOR_BRAND_FRAGMENTS,
) -> list[dict[str, Any]]:
    """Run the ranking pipeline once and return what ``rank_alternatives`` yields.

    Steps: resolve the retailer list, read the user's geography from
    ``options``, then delegate to ``rank_alternatives`` with ``top_k`` and
    ``plan_cache`` forwarded unchanged.

    Returns ``[]`` without calling ``rank_alternatives`` when no retailer
    could be resolved. This function does no scheduling, no persistence
    and no exception handling of its own; those belong to the caller.
    """
    candidates = await get_competitor_retailers(
        session, options, competitor_fragments=competitor_fragments
    )
    if candidates:
        region, postcode, distributor = get_user_geography(options)
        return await rank_alternatives(
            session,
            candidates,
            state=region,
            postcode=postcode,
            distributor=distributor,
            top_k=top_k,
            cache=plan_cache,
        )

    _LOGGER.info("ranking: no competitor retailers resolved; skipping")
    return []
Competitor retailer +# list lives in ``cdr.ranking_job`` so it stays testable without HA. +_RANKING_RUN_HOUR = 0 +_RANKING_RUN_MINUTE = 30 + def _extract_peak_rate_c_inc_gst(cdr_plan: dict[str, Any] | None) -> float | None: """Phase 3.0e — pull PEAK rate from a CDR plan envelope. @@ -244,6 +253,28 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY) self._persist_unsub: CALLBACK_TYPE | None = None + # Phase 3.1 — multi-plan ranking. Top-K cheaper alternatives + # populated by the daily 00:30 ranking job; consumed by the + # ranked-alternatives sensor (Phase 3.1 commit 6) and HA service + # (commit 5). Deep-rank (consumption replay through evaluator) + # is deferred to Phase 3.2 when HA-history backfill exists — + # without enough recorded slots, deep-rank has no signal. + self._cheap_ranked_alternatives: list[dict[str, Any]] = [] + self._ranking_last_run_at: datetime | None = None + # Plan-detail cache reused across same-day runs so a manual + # rerun via the rank_alternatives service skips re-fetching + # plans already pulled by the morning scheduled run. Cache + # clears on the FIRST run of a new calendar day so overnight + # republished plans get refreshed. + self._ranking_plan_cache: dict[str, dict[str, Any]] = {} + self._ranking_cache_date: date | None = None + self._ranking_unsub: CALLBACK_TYPE | None = None + # CR-fix: scheduled callback + manual service trigger can both + # call async_run_ranking_job concurrently. A second concurrent + # entry would interleave _ranking_plan_cache mutations and + # duplicate every expensive CDR detail fetch. Lock serialises. 
def schedule_daily_ranking(self) -> None:
    """Register the daily ranking callback at the configured local time.

    Any previously registered callback is cancelled first, so calling
    this twice leaves exactly one subscription active.
    """
    self.cancel_ranking()

    async def _fire(_now: datetime) -> None:
        # The trigger time is unused; the job reads the clock itself.
        await self.async_run_ranking_job()

    self._ranking_unsub = async_track_time_change(
        self.hass,
        _fire,
        hour=_RANKING_RUN_HOUR,
        minute=_RANKING_RUN_MINUTE,
        second=0,
    )

def cancel_ranking(self) -> None:
    """Cancel the scheduled ranking callback, if one is registered."""
    unsubscribe = self._ranking_unsub
    if unsubscribe is None:
        return
    unsubscribe()
    self._ranking_unsub = None

async def async_run_ranking_job(
    self, *, top_k: int = DEFAULT_TOP_K
) -> list[dict[str, Any]]:
    """Run the ranking pipeline and return the stored alternatives.

    Invoked by the scheduled callback and by the ``rank_alternatives``
    service. Delegates the ranking itself to ``run_ranking_job``; this
    wrapper adds locking, the per-day plan cache, exception swallowing
    and storing the result on the coordinator.

    Returns the fresh ranking when it is non-empty. Otherwise (an empty
    result or a raised exception) the previously stored list is returned
    and left untouched.
    """
    # Runs are serialised. A caller that arrives while another run is in
    # flight waits, then performs its own full run against the plan cache
    # the first run has already populated.
    # NOTE(review): the lock is built with ``asyncio.Lock()`` but the
    # import hunk in this change does not add ``import asyncio`` —
    # confirm the module already imports it.
    async with self._ranking_lock:
        # Drop the plan cache on the first run of a new local day,
        # before running, so same-day reruns stay warm.
        current_day = dt_util.now().date()
        if current_day != self._ranking_cache_date:
            self._ranking_plan_cache.clear()
            self._ranking_cache_date = current_day

        http = async_get_clientsession(self.hass)
        try:
            fresh = await run_ranking_job(
                http,
                dict(self.config_entry.options),
                top_k=top_k,
                plan_cache=self._ranking_plan_cache,
            )
        except Exception:  # noqa: BLE001 — daily job must not raise
            _LOGGER.exception("ranking: pipeline raised; keeping prior results")
            return self._cheap_ranked_alternatives

        if not fresh:
            # An empty run does not overwrite the stored ranking.
            return self._cheap_ranked_alternatives

        self._cheap_ranked_alternatives = fresh
        self._ranking_last_run_at = dt_util.now()
        _LOGGER.info(
            "ranking: persisted %d alternative(s)", len(fresh),
        )
        return fresh
Results stored on + the coordinator and exposed via the alternatives sensor (Phase 3.1 + commit 6). Cheap-rank only for now; deep-rank by HA consumption replay + arrives in Phase 3.2. + fields: + top_k: + name: Top K + description: Number of cheapest alternatives to keep (default 20) + required: false + default: 20 + example: 20 + selector: + number: + min: 1 + max: 100 + mode: slider diff --git a/tests/conftest.py b/tests/conftest.py index a95495e..eab4b59 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ def __init__(self, *args, **kwargs): "homeassistant.core": _MockModule(), "homeassistant.exceptions": _MockModule(), "homeassistant.helpers": _MockModule(), + "homeassistant.helpers.aiohttp_client": _MockModule(), "homeassistant.helpers.event": _MockModule(), "homeassistant.helpers.storage": _MockModule(), "homeassistant.helpers.update_coordinator": _MockModule(), @@ -32,6 +33,7 @@ def __init__(self, *args, **kwargs): _mods["homeassistant"].util = _mods["homeassistant.util"] _mods["homeassistant"].config_entries = _mods["homeassistant.config_entries"] _mods["homeassistant"].core = _mods["homeassistant.core"] +_mods["homeassistant.helpers"].aiohttp_client = _mods["homeassistant.helpers.aiohttp_client"] _mods["homeassistant.helpers"].event = _mods["homeassistant.helpers.event"] _mods["homeassistant.helpers"].storage = _mods["homeassistant.helpers.storage"] _mods["homeassistant.helpers"].update_coordinator = _mods["homeassistant.helpers.update_coordinator"] diff --git a/tests/test_coordinator_ranking.py b/tests/test_coordinator_ranking.py new file mode 100644 index 0000000..f74c927 --- /dev/null +++ b/tests/test_coordinator_ranking.py @@ -0,0 +1,300 @@ +"""Tests for Phase 3.1 ranking-job orchestration. + +The pure logic lives in ``cdr.ranking_job`` so it's unit-testable +without HA's app context (``PriceHawkCoordinator`` itself is +unreachable in test because its ``DataUpdateCoordinator[T]`` base +gets mocked away by ``tests/conftest.py``). 
# Patch targets: both names are looked up inside ``cdr.ranking_job``.
_REGISTRY_TARGET = "custom_components.pricehawk.cdr.ranking_job.get_registry"
_RANK_TARGET = "custom_components.pricehawk.cdr.ranking_job.rank_alternatives"


def _retailer(name: str, brand_id: str = "1") -> RetailerEndpoint:
    """Build a registry entry whose URI and CDR brand derive from *name*."""
    slug = name.lower()
    return RetailerEndpoint(
        brand_id=brand_id,
        brand_name=name,
        base_uri=f"https://example/{slug}",
        cdr_brand=slug,
    )


def _competitors(endpoints, opts, *, source="live", **kwargs):
    """Run ``get_competitor_retailers`` against a stubbed registry."""
    with patch(_REGISTRY_TARGET, AsyncMock(return_value=(endpoints, source))):
        return asyncio.run(get_competitor_retailers(None, opts, **kwargs))


def _run_job(endpoints, opts, ranked, **kwargs):
    """Run ``run_ranking_job`` with registry and ranker stubbed.

    Returns ``(result, rank_mock)`` so callers can inspect forwarded kwargs.
    """
    with patch(
        _REGISTRY_TARGET, AsyncMock(return_value=(endpoints, "live"))
    ), patch(_RANK_TARGET, AsyncMock(return_value=ranked)) as rank_mock:
        outcome = asyncio.run(run_ranking_job(None, opts, **kwargs))
    return outcome, rank_mock


def _distributor_of(opts):
    """Return only the distributor element of the geography tuple."""
    return get_user_geography(opts)[2]


# ---------------------------------------------------------------------------
# Module-level constants
# ---------------------------------------------------------------------------


class TestModuleConstants:
    def test_competitor_fragments_includes_big_4(self):
        for fragment in ("agl", "origin", "energyaustralia", "red energy"):
            assert fragment in DEFAULT_COMPETITOR_BRAND_FRAGMENTS

    def test_ranking_run_time_constants_after_midnight_rollover(self):
        """The job is pinned to 00:30."""
        from custom_components.pricehawk.coordinator import (
            _RANKING_RUN_HOUR,
            _RANKING_RUN_MINUTE,
        )
        assert (_RANKING_RUN_HOUR, _RANKING_RUN_MINUTE) == (0, 30)


# ---------------------------------------------------------------------------
# Geography extraction
# ---------------------------------------------------------------------------


class TestGetUserGeography:
    def test_extracts_postcode_and_distributor_from_cdr_plan(self):
        geography = get_user_geography({
            "cdr_postcode": "3104",
            "cdr_plan": {
                "data": {"geography": {"distributors": ["United Energy"]}}
            },
        })
        assert geography[0] is None  # state is never populated here
        assert geography[1] == "3104"
        assert geography[2] == "United Energy"

    def test_missing_postcode_returns_none(self):
        geography = get_user_geography({"cdr_plan": {"data": {"geography": {}}}})
        assert geography[1] is None

    def test_missing_cdr_plan_returns_none_distributor(self):
        geography = get_user_geography({"cdr_postcode": "3000"})
        assert geography[1] == "3000"
        assert geography[2] is None

    def test_first_distributor_picked_when_plan_has_multiple(self):
        plan = {"data": {"geography": {"distributors": ["A", "B"]}}}
        assert _distributor_of({"cdr_plan": plan}) == "A"

    def test_empty_distributors_list_returns_none(self):
        plan = {"data": {"geography": {"distributors": []}}}
        assert _distributor_of({"cdr_plan": plan}) is None

    def test_empty_options_returns_all_none(self):
        assert get_user_geography({}) == (None, None, None)

    def test_non_list_distributors_safely_skipped(self):
        """A string/dict/int ``distributors`` must not be indexed."""
        for malformed in ("United Energy", {"name": "United"}, 42):
            plan = {"data": {"geography": {"distributors": malformed}}}
            assert _distributor_of({"cdr_plan": plan}) is None

    def test_non_dict_cdr_plan_safely_skipped(self):
        """A non-dict ``cdr_plan`` yields no distributor and no exception."""
        for malformed in ("not-a-dict", ["wrong", "shape"], 42, None):
            assert _distributor_of({"cdr_plan": malformed}) is None

    def test_non_dict_data_safely_skipped(self):
        """A non-dict ``cdr_plan["data"]`` yields no distributor."""
        for malformed in ("broken", [1, 2], 0, None):
            assert _distributor_of({"cdr_plan": {"data": malformed}}) is None

    def test_non_dict_geography_safely_skipped(self):
        """A non-dict ``data["geography"]`` yields no distributor."""
        for malformed in ("str-geo", [1], 42, None):
            plan = {"data": {"geography": malformed}}
            assert _distributor_of({"cdr_plan": plan}) is None

    def test_first_distributor_must_be_string(self):
        """A non-str first element is rejected even if later ones are str."""
        plan = {"data": {"geography": {
            "distributors": [{"name": "U"}, "AGL"],
        }}}
        assert _distributor_of({"cdr_plan": plan}) is None


# ---------------------------------------------------------------------------
# Competitor retailer composition
# ---------------------------------------------------------------------------


class TestGetCompetitorRetailers:
    def test_includes_user_current_retailer_first(self):
        registry = [
            _retailer("GloBird", brand_id="100"),
            _retailer("AGL Energy", brand_id="200"),
            _retailer("Origin", brand_id="300"),
            _retailer("EnergyAustralia", brand_id="400"),
            _retailer("Red Energy", brand_id="500"),
        ]
        found = _competitors(
            registry, {"cdr_plan": {"data": {"brand": "GloBird"}}}
        )
        assert found[0].brand_name == "GloBird"
        found_names = [entry.brand_name for entry in found]
        for expected in ("AGL Energy", "Origin", "EnergyAustralia", "Red Energy"):
            assert expected in found_names

    def test_dedup_when_current_retailer_is_a_competitor(self):
        """User on AGL: AGL appears once, not twice."""
        registry = [
            _retailer("AGL Energy", brand_id="200"),
            _retailer("Origin", brand_id="300"),
        ]
        found = _competitors(registry, {"cdr_plan": {"data": {"brand": "AGL"}}})
        assert [entry.brand_id for entry in found].count("200") == 1

    def test_missing_brand_returns_only_big4(self):
        """No ``brand`` in the current plan: competitors still resolve."""
        found = _competitors(
            [_retailer("AGL Energy", brand_id="200")],
            {"cdr_plan": {"data": {}}},
        )
        assert len(found) == 1
        assert found[0].brand_name == "AGL Energy"

    def test_unmatchable_fragment_silently_skipped(self):
        """Nothing matches in an empty registry; the call must not raise."""
        found = _competitors(
            [], {"cdr_plan": {"data": {"brand": "AGL"}}}, source="baked-in"
        )
        assert found == []

    def test_custom_competitor_fragments_override_default(self):
        found = _competitors(
            [_retailer("Sumo", brand_id="999")],
            {"cdr_plan": {"data": {}}},
            competitor_fragments=("sumo",),
        )
        assert len(found) == 1
        assert found[0].brand_name == "Sumo"


# ---------------------------------------------------------------------------
# Top-level run_ranking_job
# ---------------------------------------------------------------------------


class TestRunRankingJob:
    def test_empty_retailers_returns_empty(self):
        # Only the registry is stubbed: the ranker must not be reached.
        with patch(_REGISTRY_TARGET, AsyncMock(return_value=([], "baked-in"))):
            outcome = asyncio.run(
                run_ranking_job(None, {"cdr_plan": {"data": {}}})
            )
        assert outcome == []

    def test_happy_path_forwards_geography_to_rank_alternatives(self):
        opts = {
            "cdr_postcode": "3104",
            "cdr_plan": {
                "data": {
                    "brand": "GloBird",
                    "geography": {"distributors": ["United Energy"]},
                }
            },
        }
        expected = [{"planId": "BEST"}]
        outcome, rank_mock = _run_job(
            [_retailer("GloBird", brand_id="100")], opts, expected
        )
        assert outcome == expected
        forwarded = rank_mock.call_args.kwargs
        assert forwarded["postcode"] == "3104"
        assert forwarded["distributor"] == "United Energy"

    def test_passes_plan_cache_through(self):
        shared_cache: dict = {"P1": {"cached": True}}
        _, rank_mock = _run_job(
            [_retailer("AGL Energy", brand_id="200")],
            {"cdr_plan": {"data": {"brand": "AGL"}}},
            [],
            plan_cache=shared_cache,
        )
        assert rank_mock.call_args.kwargs["cache"] is shared_cache

    def test_top_k_forwarded(self):
        _, rank_mock = _run_job(
            [_retailer("AGL Energy", brand_id="200")],
            {"cdr_plan": {"data": {"brand": "AGL"}}},
            [],
            top_k=5,
        )
        assert rank_mock.call_args.kwargs["top_k"] == 5