Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions homeassistant/components/google_assistant_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@
from __future__ import annotations

import aiohttp
from gassist_text import TextAssistant
from google.oauth2.credentials import Credentials
import voluptuous as vol

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_NAME, Platform
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import discovery
from homeassistant.helpers import discovery, intent
from homeassistant.helpers.config_entry_oauth2_flow import (
OAuth2Session,
async_get_config_entry_implementation,
)
from homeassistant.helpers.typing import ConfigType

from .const import DOMAIN
from .helpers import async_send_text_commands
from .const import CONF_LANGUAGE_CODE, DOMAIN
from .helpers import async_send_text_commands, default_language_code

SERVICE_SEND_TEXT_COMMAND = "send_text_command"
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
Expand Down Expand Up @@ -57,6 +60,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = session

await async_setup_service(hass)
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, agent)
Comment thread
tronikos marked this conversation as resolved.
Outdated

return True

Expand Down Expand Up @@ -90,3 +95,55 @@ async def send_text_command(call: ServiceCall) -> None:
send_text_command,
schema=SERVICE_SEND_TEXT_COMMAND_SCHEMA,
)


class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
"""Google Assistant SDK conversation agent."""

def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
self.assistant: TextAssistant | None = None
self.session: OAuth2Session | None = None

@property
def attribution(self):
"""Return the attribution."""
return {
"name": "Powered by Google Assistant SDK",
"url": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
}

async def async_process(
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> conversation.ConversationResult | None:
"""Process a sentence."""
if self.session:
session = self.session
else:
session = self.hass.data[DOMAIN].get(self.entry.entry_id)
self.session = session
if not session.valid_token:
await session.async_ensure_token_valid()
self.assistant = None
if not self.assistant:
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
language_code = self.entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass)
)
self.assistant = TextAssistant(credentials, language_code)

resp = self.assistant.assist(text)
text_response = resp[0]

language = language or self.hass.config.language
intent_response = intent.IntentResponse(language=language)
intent_response.async_set_speech(text_response)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way conversation_id is used that when a conversation is initiated, it is set to None.

Then in your response you return a conversation ID such that follow-up responses work.

Now in this case it doesn't seem like you're passing conversation_id so you should set conversation_id=None here, as it's not used.

Suggested change
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=None

Now in your screenshot I did notice that it's allowing follow-up commands, so I am not sure how that works then 🤷

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both DefaultAgent and AlmondAgent (the only other two agents) do exactly the same thing. conversation_id starts with None and is never modified.

Anyway for Google Assistant we cannot really have many concurrent conversations, see https://github.com/tronikos/gassist_text/blob/main/src/google/assistant/embedded/v1alpha2/embedded_assistant.proto#L314-L318

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah neither can do conversations.

)
6 changes: 5 additions & 1 deletion tests/components/google_assistant_sdk/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ async def func() -> None:
class ExpectedCredentials:
"""Assert credentials have the expected access token."""

def __init__(self, expected_access_token: str = ACCESS_TOKEN) -> None:
"""Initialize ExpectedCredentials."""
self.expected_access_token = expected_access_token

def __eq__(self, other: Credentials):
"""Return true if credentials have the expected access token."""
return other.token == ACCESS_TOKEN
return other.token == self.expected_access_token
92 changes: 92 additions & 0 deletions tests/components/google_assistant_sdk/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from homeassistant.components.google_assistant_sdk import DOMAIN
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component

from .conftest import ComponentSetup, ExpectedCredentials

Expand Down Expand Up @@ -177,3 +178,94 @@ async def test_send_text_command_expired_token_refresh_failure(
)

assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth


async def test_conversation_agent(
hass: HomeAssistant,
setup_integration: ComponentSetup,
) -> None:
"""Test GoogleAssistantConversationAgent."""
await setup_integration()

assert await async_setup_component(hass, "conversation", {})

text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
blocking=True,
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2},
blocking=True,
)

# Assert constructor is called only once since it's reused across requests
assert mock_text_assistant.call_count == 1
mock_text_assistant.assert_called_once_with(ExpectedCredentials(), "en-US")
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])


async def test_conversation_agent_refresh_token(
hass: HomeAssistant,
setup_integration: ComponentSetup,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test GoogleAssistantConversationAgent when token is expired."""
await setup_integration()

assert await async_setup_component(hass, "conversation", {})

entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED

text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
blocking=True,
)

# Expire the token between requests
entry.data["token"]["expires_at"] = time.time() - 3600
updated_access_token = "updated-access-token"
aioclient_mock.post(
"https://oauth2.googleapis.com/token",
json={
"access_token": updated_access_token,
"refresh_token": "updated-refresh-token",
"expires_at": time.time() + 3600,
"expires_in": 3600,
},
)

await hass.services.async_call(
"conversation",
"process",
{"text": text2},
blocking=True,
)

# Assert constructor is called twice since the token was expired
assert mock_text_assistant.call_count == 2
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")])
mock_text_assistant.assert_has_calls(
[call(ExpectedCredentials(updated_access_token), "en-US")]
)
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])