From c044ca32b974e1b06f9e531b39f504fe97d77f51 Mon Sep 17 00:00:00 2001 From: tronikos Date: Mon, 9 Jan 2023 08:50:16 +0000 Subject: [PATCH 1/5] Google Assistant SDK conversation agent --- .../google_assistant_sdk/__init__.py | 62 +++++++++++++++++-- 1 file changed, 57 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/google_assistant_sdk/__init__.py b/homeassistant/components/google_assistant_sdk/__init__.py index 119ba9e1d27c6..b15b3c7d0c649 100644 --- a/homeassistant/components/google_assistant_sdk/__init__.py +++ b/homeassistant/components/google_assistant_sdk/__init__.py @@ -2,21 +2,24 @@ from __future__ import annotations import aiohttp +from gassist_text import TextAssistant +from google.oauth2.credentials import Credentials import voluptuous as vol +from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry, ConfigEntryState -from homeassistant.const import CONF_NAME, Platform -from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform +from homeassistant.core import Context, HomeAssistant, ServiceCall from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady -from homeassistant.helpers import discovery +from homeassistant.helpers import discovery, intent from homeassistant.helpers.config_entry_oauth2_flow import ( OAuth2Session, async_get_config_entry_implementation, ) from homeassistant.helpers.typing import ConfigType -from .const import DOMAIN -from .helpers import async_send_text_commands +from .const import CONF_LANGUAGE_CODE, DOMAIN +from .helpers import async_send_text_commands, default_language_code SERVICE_SEND_TEXT_COMMAND = "send_text_command" SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command" @@ -57,6 +60,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.data.setdefault(DOMAIN, {})[entry.entry_id] = session await async_setup_service(hass) + agent = GoogleAssistantConversationAgent(hass, entry) + conversation.async_set_agent(hass, agent) return True @@ -90,3 +95,50 @@ async def send_text_command(call: ServiceCall) -> None: send_text_command, schema=SERVICE_SEND_TEXT_COMMAND_SCHEMA, ) + + +class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): + """Google Assistant SDK conversation agent.""" + + def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + """Initialize the agent.""" + self.hass = hass + self.entry = entry + self.assistant: TextAssistant | None = None + + @property + def attribution(self): + """Return the attribution.""" + return { + "name": "Powered by Google Assistant SDK", + "url": "https://www.home-assistant.io/integrations/google_assistant_sdk/", + } + + async def async_process( + self, + text: str, + context: Context, + conversation_id: str | None = None, + language: str | None = None, + ) -> conversation.ConversationResult | None: + """Process a sentence.""" + if not self.assistant: + language_code = self.entry.options.get( + CONF_LANGUAGE_CODE, default_language_code(self.hass) + ) + session: OAuth2Session = self.hass.data[DOMAIN].get(self.entry.entry_id) + credentials = Credentials( + session.token[CONF_ACCESS_TOKEN], + refresh_token=session.token["refresh_token"], + ) + self.assistant = TextAssistant(credentials, language_code) + + resp = self.assistant.assist(text) + text_response = resp[0] + + language = language or self.hass.config.language + intent_response = intent.IntentResponse(language=language) + intent_response.async_set_speech(text_response) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) From 3d802991201c5c0ef289152ef5ee5234e14c5b6d Mon Sep 17 00:00:00 2001 From: tronikos Date: Mon, 9 Jan 2023 09:54:57 +0000 Subject: [PATCH 2/5] refresh token --- .../components/google_assistant_sdk/__init__.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/google_assistant_sdk/__init__.py b/homeassistant/components/google_assistant_sdk/__init__.py index b15b3c7d0c649..b7be4d3f9f7ed 100644 --- a/homeassistant/components/google_assistant_sdk/__init__.py +++ b/homeassistant/components/google_assistant_sdk/__init__.py @@ -100,11 +100,13 @@ async def send_text_command(call: ServiceCall) -> None: class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): """Google Assistant SDK conversation agent.""" + assistant: TextAssistant + session: OAuth2Session + def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: """Initialize the agent.""" self.hass = hass self.entry = entry - self.assistant: TextAssistant | None = None @property def attribution(self): @@ -122,15 +124,16 @@ async def async_process( language: str | None = None, ) -> conversation.ConversationResult | None: """Process a sentence.""" + if not self.session: + self.session = self.hass.data[DOMAIN].get(self.entry.entry_id) + if not self.session.valid_token: + await self.session.async_ensure_token_valid() + self.assistant = None if not self.assistant: + credentials = Credentials(self.session.token[CONF_ACCESS_TOKEN]) language_code = self.entry.options.get( CONF_LANGUAGE_CODE, default_language_code(self.hass) ) - session: OAuth2Session = self.hass.data[DOMAIN].get(self.entry.entry_id) - credentials = Credentials( - session.token[CONF_ACCESS_TOKEN], - refresh_token=session.token["refresh_token"], - ) self.assistant = TextAssistant(credentials, language_code) resp = self.assistant.assist(text) From 384709acd69aa5f122494daa68ddd9bb8f54617d Mon Sep 17 00:00:00 2001 From: tronikos Date: Mon, 9 Jan 2023 10:21:59 +0000 Subject: [PATCH 3/5] fix session --- .../google_assistant_sdk/__init__.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/google_assistant_sdk/__init__.py b/homeassistant/components/google_assistant_sdk/__init__.py index b7be4d3f9f7ed..72855c46a8c96 100644 --- a/homeassistant/components/google_assistant_sdk/__init__.py +++ b/homeassistant/components/google_assistant_sdk/__init__.py @@ -100,13 +100,12 @@ async def send_text_command(call: ServiceCall) -> None: class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): """Google Assistant SDK conversation agent.""" - assistant: TextAssistant - session: OAuth2Session - def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: """Initialize the agent.""" self.hass = hass self.entry = entry + self.assistant: TextAssistant | None = None + self.session: OAuth2Session | None = None @property def attribution(self): @@ -124,13 +123,16 @@ async def async_process( language: str | None = None, ) -> conversation.ConversationResult | None: """Process a sentence.""" - if not self.session: - self.session = self.hass.data[DOMAIN].get(self.entry.entry_id) - if not self.session.valid_token: - await self.session.async_ensure_token_valid() + if self.session: + session = self.session + else: + session = self.hass.data[DOMAIN].get(self.entry.entry_id) + self.session = session + if not session.valid_token: + await session.async_ensure_token_valid() self.assistant = None if not self.assistant: - credentials = Credentials(self.session.token[CONF_ACCESS_TOKEN]) + credentials = Credentials(session.token[CONF_ACCESS_TOKEN]) language_code = self.entry.options.get( CONF_LANGUAGE_CODE, default_language_code(self.hass) ) From 9851598a071782898e2f0f515ac4fa00b67e9205 Mon Sep 17 00:00:00 2001 From: tronikos Date: Mon, 9 Jan 2023 11:26:55 +0000 Subject: [PATCH 4/5] Add tests --- .../google_assistant_sdk/conftest.py | 6 +- .../google_assistant_sdk/test_init.py | 92 +++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/tests/components/google_assistant_sdk/conftest.py b/tests/components/google_assistant_sdk/conftest.py index 9730c0fef1781..207ceccb342b7 100644 --- a/tests/components/google_assistant_sdk/conftest.py +++ b/tests/components/google_assistant_sdk/conftest.py @@ -87,6 +87,10 @@ async def func() -> None: class ExpectedCredentials: """Assert credentials have the expected access token.""" + def __init__(self, expected_access_token: str = ACCESS_TOKEN) -> None: + """Initialize ExpectedCredentials.""" + self.expected_access_token = expected_access_token + def __eq__(self, other: Credentials): """Return true if credentials have the expected access token.""" - return other.token == ACCESS_TOKEN + return other.token == self.expected_access_token diff --git a/tests/components/google_assistant_sdk/test_init.py b/tests/components/google_assistant_sdk/test_init.py index afc5e77042fbe..09f402de6ed8d 100644 --- a/tests/components/google_assistant_sdk/test_init.py +++ b/tests/components/google_assistant_sdk/test_init.py @@ -9,6 +9,7 @@ from homeassistant.components.google_assistant_sdk import DOMAIN from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component from .conftest import ComponentSetup, ExpectedCredentials @@ -177,3 +178,94 @@ async def test_send_text_command_expired_token_refresh_failure( ) assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth + + +async def test_conversation_agent( + hass: HomeAssistant, + setup_integration: ComponentSetup, +) -> None: + """Test GoogleAssistantConversationAgent.""" + await setup_integration() + + assert await async_setup_component(hass, "conversation", {}) + + text1 = "tell me a joke" + text2 = "tell me another one" + with patch( + "homeassistant.components.google_assistant_sdk.TextAssistant" + ) as mock_text_assistant: + await hass.services.async_call( + "conversation", + "process", + {"text": text1}, + blocking=True, + ) + await hass.services.async_call( + "conversation", + "process", + {"text": text2}, + blocking=True, + ) + + # Assert constructor is called only once since it's reused across requests + assert mock_text_assistant.call_count == 1 + mock_text_assistant.assert_called_once_with(ExpectedCredentials(), "en-US") + mock_text_assistant.assert_has_calls([call().assist(text1)]) + mock_text_assistant.assert_has_calls([call().assist(text2)]) + + +async def test_conversation_agent_refresh_token( + hass: HomeAssistant, + setup_integration: ComponentSetup, + aioclient_mock: AiohttpClientMocker, +) -> None: + """Test GoogleAssistantConversationAgent when token is expired.""" + await setup_integration() + + assert await async_setup_component(hass, "conversation", {}) + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + entry = entries[0] + assert entry.state is ConfigEntryState.LOADED + + text1 = "tell me a joke" + text2 = "tell me another one" + with patch( + "homeassistant.components.google_assistant_sdk.TextAssistant" + ) as mock_text_assistant: + await hass.services.async_call( + "conversation", + "process", + {"text": text1}, + blocking=True, + ) + + # Expire the token between requests + entry.data["token"]["expires_at"] = time.time() - 3600 + updated_access_token = "updated-access-token" + aioclient_mock.post( + "https://oauth2.googleapis.com/token", + json={ + "access_token": updated_access_token, + "refresh_token": "updated-refresh-token", + "expires_at": time.time() + 3600, + "expires_in": 3600, + }, + ) + + await hass.services.async_call( + "conversation", + "process", + {"text": text2}, + blocking=True, + ) + + # Assert constructor is called twice since the token was expired + assert mock_text_assistant.call_count == 2 + mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")]) + mock_text_assistant.assert_has_calls( + [call(ExpectedCredentials(updated_access_token), "en-US")] + ) + mock_text_assistant.assert_has_calls([call().assist(text1)]) + mock_text_assistant.assert_has_calls([call().assist(text2)]) From 9de39b8bd9d655e05d3444f79bb87b4a8e72041c Mon Sep 17 00:00:00 2001 From: tronikos Date: Mon, 9 Jan 2023 19:59:49 +0000 Subject: [PATCH 5/5] Add option to enable conversation agent --- .../google_assistant_sdk/__init__.py | 16 +++++-- .../google_assistant_sdk/config_flow.py | 14 +++++- .../components/google_assistant_sdk/const.py | 2 + .../google_assistant_sdk/strings.json | 4 +- .../google_assistant_sdk/translations/en.json | 4 +- .../google_assistant_sdk/test_config_flow.py | 44 +++++++++++++++---- .../google_assistant_sdk/test_init.py | 13 ++++++ 7 files changed, 82 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/google_assistant_sdk/__init__.py b/homeassistant/components/google_assistant_sdk/__init__.py index 72855c46a8c96..59c065ecbb6b1 100644 --- a/homeassistant/components/google_assistant_sdk/__init__.py +++ b/homeassistant/components/google_assistant_sdk/__init__.py @@ -18,7 +18,7 @@ ) from homeassistant.helpers.typing import ConfigType -from .const import CONF_LANGUAGE_CODE, DOMAIN +from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN from .helpers import async_send_text_commands, default_language_code SERVICE_SEND_TEXT_COMMAND = "send_text_command" @@ -60,8 +60,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.data.setdefault(DOMAIN, {})[entry.entry_id] = session await async_setup_service(hass) - agent = GoogleAssistantConversationAgent(hass, entry) - conversation.async_set_agent(hass, agent) + + entry.async_on_unload(entry.add_update_listener(update_listener)) + await update_listener(hass, entry) return True @@ -97,6 +98,15 @@ async def send_text_command(call: ServiceCall) -> None: ) +async def update_listener(hass, entry): + """Handle options update.""" + if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False): + agent = GoogleAssistantConversationAgent(hass, entry) + conversation.async_set_agent(hass, agent) + else: + conversation.async_set_agent(hass, None) + + class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): """Google Assistant SDK conversation agent.""" diff --git a/homeassistant/components/google_assistant_sdk/config_flow.py b/homeassistant/components/google_assistant_sdk/config_flow.py index b4f617ca02925..b93a3be93f2fd 100644 --- a/homeassistant/components/google_assistant_sdk/config_flow.py +++ b/homeassistant/components/google_assistant_sdk/config_flow.py @@ -13,7 +13,13 @@ from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import config_entry_oauth2_flow -from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES +from .const import ( + CONF_ENABLE_CONVERSATION_AGENT, + CONF_LANGUAGE_CODE, + DEFAULT_NAME, + DOMAIN, + SUPPORTED_LANGUAGE_CODES, +) from .helpers import default_language_code _LOGGER = logging.getLogger(__name__) @@ -108,6 +114,12 @@ async def async_step_init( CONF_LANGUAGE_CODE, default=self.config_entry.options.get(CONF_LANGUAGE_CODE), ): vol.In(SUPPORTED_LANGUAGE_CODES), + vol.Required( + CONF_ENABLE_CONVERSATION_AGENT, + default=self.config_entry.options.get( + CONF_ENABLE_CONVERSATION_AGENT + ), + ): bool, } ), ) diff --git a/homeassistant/components/google_assistant_sdk/const.py b/homeassistant/components/google_assistant_sdk/const.py index 8458145caace9..1b77b58d0fbcb 100644 --- a/homeassistant/components/google_assistant_sdk/const.py +++ b/homeassistant/components/google_assistant_sdk/const.py @@ -24,3 +24,5 @@ "ko-KR", "pt-BR", ] + +CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent" diff --git a/homeassistant/components/google_assistant_sdk/strings.json b/homeassistant/components/google_assistant_sdk/strings.json index 66a2b975b5e15..d4c85be91e50c 100644 --- a/homeassistant/components/google_assistant_sdk/strings.json +++ b/homeassistant/components/google_assistant_sdk/strings.json @@ -31,8 +31,10 @@ "step": { "init": { "data": { + "enable_conversation_agent": "Enable the conversation agent", "language_code": "Language code" - } + }, + "description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent." } } }, diff --git a/homeassistant/components/google_assistant_sdk/translations/en.json b/homeassistant/components/google_assistant_sdk/translations/en.json index 36d28427ca27d..cd23f86e2e07c 100644 --- a/homeassistant/components/google_assistant_sdk/translations/en.json +++ b/homeassistant/components/google_assistant_sdk/translations/en.json @@ -34,8 +34,10 @@ "step": { "init": { "data": { + "enable_conversation_agent": "Enable conversation agent", "language_code": "Language code" - } + }, + "description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent." } } } diff --git a/tests/components/google_assistant_sdk/test_config_flow.py b/tests/components/google_assistant_sdk/test_config_flow.py index af5f0e73c75e8..a0f22d814b1f7 100644 --- a/tests/components/google_assistant_sdk/test_config_flow.py +++ b/tests/components/google_assistant_sdk/test_config_flow.py @@ -221,39 +221,65 @@ async def test_options_flow( assert result["type"] == "form" assert result["step_id"] == "init" data_schema = result["data_schema"].schema - assert set(data_schema) == {"language_code"} + assert set(data_schema) == {"enable_conversation_agent", "language_code"} result = await hass.config_entries.options.async_configure( result["flow_id"], - user_input={"language_code": "es-ES"}, + user_input={"enable_conversation_agent": False, "language_code": "es-ES"}, ) assert result["type"] == "create_entry" - assert config_entry.options == {"language_code": "es-ES"} + assert config_entry.options == { + "enable_conversation_agent": False, + "language_code": "es-ES", + } # Retrigger options flow, not change language result = await hass.config_entries.options.async_init(config_entry.entry_id) assert result["type"] == "form" assert result["step_id"] == "init" data_schema = result["data_schema"].schema - assert set(data_schema) == {"language_code"} + assert set(data_schema) == {"enable_conversation_agent", "language_code"} result = await hass.config_entries.options.async_configure( result["flow_id"], - user_input={"language_code": "es-ES"}, + user_input={"enable_conversation_agent": False, "language_code": "es-ES"}, ) assert result["type"] == "create_entry" - assert config_entry.options == {"language_code": "es-ES"} + assert config_entry.options == { + "enable_conversation_agent": False, + "language_code": "es-ES", + } # Retrigger options flow, change language result = await hass.config_entries.options.async_init(config_entry.entry_id) assert result["type"] == "form" assert result["step_id"] == "init" data_schema = result["data_schema"].schema - assert set(data_schema) == {"language_code"} + assert set(data_schema) == {"enable_conversation_agent", "language_code"} result = await hass.config_entries.options.async_configure( result["flow_id"], - user_input={"language_code": "en-US"}, + user_input={"enable_conversation_agent": False, "language_code": "en-US"}, ) assert result["type"] == "create_entry" - assert config_entry.options == {"language_code": "en-US"} + assert config_entry.options == { + "enable_conversation_agent": False, + "language_code": "en-US", + } + + # Retrigger options flow, enable conversation agent + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == "form" + assert result["step_id"] == "init" + data_schema = result["data_schema"].schema + assert set(data_schema) == {"enable_conversation_agent", "language_code"} + + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input={"enable_conversation_agent": True, "language_code": "en-US"}, + ) + assert result["type"] == "create_entry" + assert config_entry.options == { + "enable_conversation_agent": True, + "language_code": "en-US", + } diff --git a/tests/components/google_assistant_sdk/test_init.py b/tests/components/google_assistant_sdk/test_init.py index 09f402de6ed8d..b93f83feda726 100644 --- a/tests/components/google_assistant_sdk/test_init.py +++ b/tests/components/google_assistant_sdk/test_init.py @@ -189,6 +189,15 @@ async def test_conversation_agent( assert await async_setup_component(hass, "conversation", {}) + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + entry = entries[0] + assert entry.state is ConfigEntryState.LOADED + hass.config_entries.async_update_entry( + entry, options={"enable_conversation_agent": True} + ) + await hass.async_block_till_done() + text1 = "tell me a joke" text2 = "tell me another one" with patch( @@ -228,6 +237,10 @@ async def test_conversation_agent_refresh_token( assert len(entries) == 1 entry = entries[0] assert entry.state is ConfigEntryState.LOADED + hass.config_entries.async_update_entry( + entry, options={"enable_conversation_agent": True} + ) + await hass.async_block_till_done() text1 = "tell me a joke" text2 = "tell me another one"