Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 72 additions & 5 deletions homeassistant/components/google_assistant_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@
from __future__ import annotations

import aiohttp
from gassist_text import TextAssistant
from google.oauth2.credentials import Credentials
import voluptuous as vol

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_NAME, Platform
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import discovery
from homeassistant.helpers import discovery, intent
from homeassistant.helpers.config_entry_oauth2_flow import (
OAuth2Session,
async_get_config_entry_implementation,
)
from homeassistant.helpers.typing import ConfigType

from .const import DOMAIN
from .helpers import async_send_text_commands
from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN
from .helpers import async_send_text_commands, default_language_code

SERVICE_SEND_TEXT_COMMAND = "send_text_command"
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
Expand Down Expand Up @@ -58,6 +61,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:

await async_setup_service(hass)

entry.async_on_unload(entry.add_update_listener(update_listener))
await update_listener(hass, entry)

return True


Expand Down Expand Up @@ -90,3 +96,64 @@ async def send_text_command(call: ServiceCall) -> None:
send_text_command,
schema=SERVICE_SEND_TEXT_COMMAND_SCHEMA,
)


async def update_listener(hass, entry):
"""Handle options update."""
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, agent)
else:
conversation.async_set_agent(hass, None)


class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
"""Google Assistant SDK conversation agent."""

def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
self.assistant: TextAssistant | None = None
self.session: OAuth2Session | None = None

@property
def attribution(self):
"""Return the attribution."""
return {
"name": "Powered by Google Assistant SDK",
"url": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
}

async def async_process(
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> conversation.ConversationResult | None:
"""Process a sentence."""
if self.session:
session = self.session
else:
session = self.hass.data[DOMAIN].get(self.entry.entry_id)
self.session = session
if not session.valid_token:
await session.async_ensure_token_valid()
self.assistant = None
if not self.assistant:
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
language_code = self.entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass)
)
self.assistant = TextAssistant(credentials, language_code)

resp = self.assistant.assist(text)
text_response = resp[0]

language = language or self.hass.config.language
intent_response = intent.IntentResponse(language=language)
intent_response.async_set_speech(text_response)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way conversation_id is used that when a conversation is initiated, it is set to None.

Then in your response you return a conversation ID such that follow-up responses work.

Now in this case it doesn't seem like you're passing conversation_id so you should set conversation_id=None here, as it's not used.

Suggested change
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=None

Now in your screenshot I did notice that it's allowing follow-up commands, so I am not sure how that works then 🤷

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both DefaultAgent and AlmondAgent (the only other two agents) do exactly the same thing. conversation_id starts with None and is never modified.

Anyway for Google Assistant we cannot really have many concurrent conversations, see https://github.com/tronikos/gassist_text/blob/main/src/google/assistant/embedded/v1alpha2/embedded_assistant.proto#L314-L318

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah neither can do conversations.

)
14 changes: 13 additions & 1 deletion homeassistant/components/google_assistant_sdk/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_entry_oauth2_flow

from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
from .const import (
CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DEFAULT_NAME,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
from .helpers import default_language_code

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -108,6 +114,12 @@ async def async_step_init(
CONF_LANGUAGE_CODE,
default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
): vol.In(SUPPORTED_LANGUAGE_CODES),
vol.Required(
CONF_ENABLE_CONVERSATION_AGENT,
Comment thread
tronikos marked this conversation as resolved.
default=self.config_entry.options.get(
CONF_ENABLE_CONVERSATION_AGENT
),
): bool,
}
),
)
2 changes: 2 additions & 0 deletions homeassistant/components/google_assistant_sdk/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@
"ko-KR",
"pt-BR",
]

CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
"step": {
"init": {
"data": {
"enable_conversation_agent": "Enable the conversation agent",
"language_code": "Language code"
}
},
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
}
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
"step": {
"init": {
"data": {
"enable_conversation_agent": "Enable conversation agent",
"language_code": "Language code"
}
},
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion tests/components/google_assistant_sdk/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ async def func() -> None:
class ExpectedCredentials:
"""Assert credentials have the expected access token."""

def __init__(self, expected_access_token: str = ACCESS_TOKEN) -> None:
"""Initialize ExpectedCredentials."""
self.expected_access_token = expected_access_token

def __eq__(self, other: Credentials):
"""Return true if credentials have the expected access token."""
return other.token == ACCESS_TOKEN
return other.token == self.expected_access_token
44 changes: 35 additions & 9 deletions tests/components/google_assistant_sdk/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,39 +221,65 @@ async def test_options_flow(
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"}
assert set(data_schema) == {"enable_conversation_agent", "language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"language_code": "es-ES"},
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "es-ES"}
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}

# Retrigger options flow, not change language
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"}
assert set(data_schema) == {"enable_conversation_agent", "language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"language_code": "es-ES"},
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "es-ES"}
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}

# Retrigger options flow, change language
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"}
assert set(data_schema) == {"enable_conversation_agent", "language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"language_code": "en-US"},
user_input={"enable_conversation_agent": False, "language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "en-US"}
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "en-US",
}

# Retrigger options flow, enable conversation agent
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}

result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": True, "language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": True,
"language_code": "en-US",
}
105 changes: 105 additions & 0 deletions tests/components/google_assistant_sdk/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from homeassistant.components.google_assistant_sdk import DOMAIN
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component

from .conftest import ComponentSetup, ExpectedCredentials

Expand Down Expand Up @@ -177,3 +178,107 @@ async def test_send_text_command_expired_token_refresh_failure(
)

assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth


async def test_conversation_agent(
hass: HomeAssistant,
setup_integration: ComponentSetup,
) -> None:
"""Test GoogleAssistantConversationAgent."""
await setup_integration()

assert await async_setup_component(hass, "conversation", {})

entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()

text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
blocking=True,
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2},
blocking=True,
)

# Assert constructor is called only once since it's reused across requests
assert mock_text_assistant.call_count == 1
mock_text_assistant.assert_called_once_with(ExpectedCredentials(), "en-US")
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])


async def test_conversation_agent_refresh_token(
hass: HomeAssistant,
setup_integration: ComponentSetup,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test GoogleAssistantConversationAgent when token is expired."""
await setup_integration()

assert await async_setup_component(hass, "conversation", {})

entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()

text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
blocking=True,
)

# Expire the token between requests
entry.data["token"]["expires_at"] = time.time() - 3600
updated_access_token = "updated-access-token"
aioclient_mock.post(
"https://oauth2.googleapis.com/token",
json={
"access_token": updated_access_token,
"refresh_token": "updated-refresh-token",
"expires_at": time.time() + 3600,
"expires_in": 3600,
},
)

await hass.services.async_call(
"conversation",
"process",
{"text": text2},
blocking=True,
)

# Assert constructor is called twice since the token was expired
assert mock_text_assistant.call_count == 2
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")])
mock_text_assistant.assert_has_calls(
[call(ExpectedCredentials(updated_access_token), "en-US")]
)
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])