Skip to content

Commit f3aab12

Browse files
authored
feat: saved api_key to keychain for user (#104)
1 parent dfecf82 commit f3aab12

File tree

6 files changed

+238
-79
lines changed

6 files changed

+238
-79
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111
"ruamel-yaml>=0.18.6",
1212
"click>=8.1.7",
1313
"prompt-toolkit>=3.0.47",
14+
"keyring>=25.4.1",
1415
]
1516
author = [{ name = "Block", email = "[email protected]" }]
1617
packages = [{ include = "goose", from = "src" }]

src/goose/cli/session.py

+9-45
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,21 @@
1-
import sys
21
import traceback
32
from pathlib import Path
43
from typing import Any, Dict, List, Optional
54

6-
from exchange import Message, ToolResult, ToolUse, Text, Exchange
7-
from exchange.providers.base import MissingProviderEnvVariableError
8-
from exchange.invalid_choice_error import InvalidChoiceError
5+
from exchange import Message, ToolResult, ToolUse, Text
96
from rich import print
10-
from rich.console import RenderableType
11-
from rich.live import Live
127
from rich.markdown import Markdown
138
from rich.panel import Panel
149
from rich.status import Status
1510

16-
from goose.build import build_exchange
17-
from goose.cli.config import PROFILES_CONFIG_PATH, ensure_config, session_path, LOG_PATH
11+
from goose.cli.config import ensure_config, session_path, LOG_PATH
1812
from goose._logger import get_logger, setup_logging
1913
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
20-
from goose.notifier import Notifier
14+
from goose.cli.session_notifier import SessionNotifier
2115
from goose.profile import Profile
2216
from goose.utils import droid, load_plugins
2317
from goose.utils._cost_calculator import get_total_cost_message
18+
from goose.utils._create_exchange import create_exchange
2419
from goose.utils.session_file import read_or_create_file, save_latest_session
2520

2621
RESUME_MESSAGE = "I see we were interrupted. How can I help you?"
@@ -52,24 +47,6 @@ def load_profile(name: Optional[str]) -> Profile:
5247
return profile
5348

5449

55-
class SessionNotifier(Notifier):
56-
def __init__(self, status_indicator: Status) -> None:
57-
self.status_indicator = status_indicator
58-
self.live = Live(self.status_indicator, refresh_per_second=8, transient=True)
59-
60-
def log(self, content: RenderableType) -> None:
61-
print(content)
62-
63-
def status(self, status: str) -> None:
64-
self.status_indicator.update(status)
65-
66-
def start(self) -> None:
67-
self.live.start()
68-
69-
def stop(self) -> None:
70-
self.live.stop()
71-
72-
7350
class Session:
7451
"""A session handler for managing interactions between a user and the Goose exchange
7552
@@ -89,10 +66,12 @@ def __init__(
8966
self.name = droid()
9067
else:
9168
self.name = name
92-
self.profile = profile
69+
self.profile_name = profile
70+
self.prompt_session = GoosePromptSession()
9371
self.status_indicator = Status("", spinner="dots")
9472
self.notifier = SessionNotifier(self.status_indicator)
95-
self.exchange = self._create_exchange()
73+
74+
self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier)
9675
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)
9776

9877
self.exchange.messages.extend(self._get_initial_messages())
@@ -102,21 +81,6 @@ def __init__(
10281

10382
self.prompt_session = GoosePromptSession()
10483

105-
def _create_exchange(self) -> Exchange:
106-
try:
107-
return build_exchange(profile=load_profile(self.profile), notifier=self.notifier)
108-
except MissingProviderEnvVariableError as e:
109-
error_message = f"{e.message}. Please set the required environment variable to continue."
110-
print(Panel(error_message, style="red"))
111-
sys.exit(1)
112-
except InvalidChoiceError as e:
113-
error_message = (
114-
f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n"
115-
+ "Configuration doc: https://block-open-source.github.io/goose/configuration.html"
116-
)
117-
print(error_message)
118-
sys.exit(1)
119-
12084
def _get_initial_messages(self) -> List[Message]:
12185
messages = self.load_session()
12286

@@ -162,7 +126,7 @@ def run(self) -> None:
162126
Runs the main loop to handle user inputs and responses.
163127
Continues until an empty string is returned from the prompt.
164128
"""
165-
print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile or 'default'}[/]")
129+
print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]")
166130
print(f"[dim]saving to {self.session_file_path}")
167131
print()
168132
message = self.process_first_message()

src/goose/cli/session_notifier.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from rich.status import Status
2+
from rich.live import Live
3+
from rich.console import RenderableType
4+
from rich import print
5+
6+
from goose.notifier import Notifier
7+
8+
9+
class SessionNotifier(Notifier):
10+
def __init__(self, status_indicator: Status) -> None:
11+
self.status_indicator = status_indicator
12+
self.live = Live(self.status_indicator, refresh_per_second=8, transient=True)
13+
14+
def log(self, content: RenderableType) -> None:
15+
print(content)
16+
17+
def status(self, status: str) -> None:
18+
self.status_indicator.update(status)
19+
20+
def start(self) -> None:
21+
self.live.start()
22+
23+
def stop(self) -> None:
24+
self.live.stop()

src/goose/utils/_create_exchange.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
import sys
3+
from typing import Optional
4+
import keyring
5+
6+
from prompt_toolkit import prompt
7+
from prompt_toolkit.shortcuts import confirm
8+
from rich import print
9+
from rich.panel import Panel
10+
11+
from goose.build import build_exchange
12+
from goose.cli.config import PROFILES_CONFIG_PATH
13+
from goose.cli.session_notifier import SessionNotifier
14+
from goose.profile import Profile
15+
from exchange import Exchange
16+
from exchange.invalid_choice_error import InvalidChoiceError
17+
from exchange.providers.base import MissingProviderEnvVariableError
18+
19+
20+
def create_exchange(profile: Profile, notifier: SessionNotifier) -> Exchange:
21+
try:
22+
return build_exchange(profile, notifier=notifier)
23+
except InvalidChoiceError as e:
24+
error_message = (
25+
f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n"
26+
+ "Configuration doc: https://block-open-source.github.io/goose/configuration.html"
27+
)
28+
print(error_message)
29+
sys.exit(1)
30+
except MissingProviderEnvVariableError as e:
31+
api_key = _get_api_key_from_keychain(e.env_variable, e.provider)
32+
if api_key is None or api_key == "":
33+
error_message = f"{e.message}. Please set the required environment variable to continue."
34+
print(Panel(error_message, style="red"))
35+
sys.exit(1)
36+
else:
37+
os.environ[e.env_variable] = api_key
38+
return build_exchange(profile=profile, notifier=notifier)
39+
40+
41+
def _get_api_key_from_keychain(env_variable: str, provider: str) -> Optional[str]:
42+
api_key = keyring.get_password("goose", env_variable)
43+
if api_key is not None:
44+
print(f"Using {env_variable} value for {provider} from your keychain")
45+
else:
46+
api_key = prompt(f"Enter {env_variable} value for {provider}:".strip())
47+
if api_key is not None and len(api_key) > 0:
48+
save_to_keyring = confirm(f"Would you like to save the {env_variable} value to your keychain?")
49+
if save_to_keyring:
50+
keyring.set_password("goose", env_variable, api_key)
51+
print(f"Saved {env_variable} to your key_chain. service_name: goose, user_name: {env_variable}")
52+
return api_key

tests/cli/test_session.py

+1-34
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
import pytest
44
from exchange import Exchange, Message, ToolUse, ToolResult
5-
from exchange.providers.base import MissingProviderEnvVariableError
6-
from exchange.invalid_choice_error import InvalidChoiceError
75
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
86
from goose.cli.prompt.user_input import PromptAction, UserInput
97
from goose.cli.session import Session
@@ -22,7 +20,7 @@ def mock_specified_session_name():
2220
@pytest.fixture
2321
def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory):
2422
with (
25-
patch("goose.cli.session.build_exchange") as mock_exchange,
23+
patch("goose.cli.session.create_exchange") as mock_exchange,
2624
patch("goose.cli.session.load_profile", return_value=profile_factory()),
2725
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
2826
patch("goose.cli.session.load_provider", return_value="provider"),
@@ -158,34 +156,3 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi
158156
with patch("goose.cli.session.droid", return_value=generated_session_name):
159157
session = create_session_with_mock_configs({"name": None})
160158
assert session.name == generated_session_name
161-
162-
163-
def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path):
164-
session = create_session_with_mock_configs()
165-
expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai")
166-
with (
167-
patch("goose.cli.session.build_exchange", side_effect=expected_error),
168-
patch("goose.cli.session.print") as mock_print,
169-
patch("sys.exit") as mock_exit,
170-
):
171-
session._create_exchange()
172-
mock_print.call_args_list[0][0][0].renderable == (
173-
"Missing environment variable OPENAI_API_KEY for provider openai. ",
174-
"Please set the required environment variable to continue.",
175-
)
176-
mock_exit.assert_called_once_with(1)
177-
178-
179-
def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path):
180-
session = create_session_with_mock_configs()
181-
expected_error = InvalidChoiceError(
182-
attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"]
183-
)
184-
with (
185-
patch("goose.cli.session.build_exchange", side_effect=expected_error),
186-
patch("goose.cli.session.print") as mock_print,
187-
patch("sys.exit") as mock_exit,
188-
):
189-
session._create_exchange()
190-
assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0]
191-
mock_exit.assert_called_once_with(1)

tests/utils/test_create_exchange.py

+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import os
2+
from unittest.mock import MagicMock, patch
3+
4+
from exchange.exchange import Exchange
5+
from exchange.invalid_choice_error import InvalidChoiceError
6+
from exchange.providers.base import MissingProviderEnvVariableError
7+
import pytest
8+
9+
from goose.notifier import Notifier
10+
from goose.profile import Profile
11+
from goose.utils._create_exchange import create_exchange
12+
13+
TEST_PROFILE = MagicMock(spec=Profile)
14+
TEST_EXCHANGE = MagicMock(spec=Exchange)
15+
TEST_NOTIFIER = MagicMock(spec=Notifier)
16+
17+
18+
@pytest.fixture
19+
def mock_print():
20+
with patch("goose.utils._create_exchange.print") as mock_print:
21+
yield mock_print
22+
23+
24+
@pytest.fixture
25+
def mock_prompt():
26+
with patch("goose.utils._create_exchange.prompt") as mock_prompt:
27+
yield mock_prompt
28+
29+
30+
@pytest.fixture
31+
def mock_confirm():
32+
with patch("goose.utils._create_exchange.confirm") as mock_confirm:
33+
yield mock_confirm
34+
35+
36+
@pytest.fixture
37+
def mock_sys_exit():
38+
with patch("sys.exit") as mock_exit:
39+
yield mock_exit
40+
41+
42+
@pytest.fixture
43+
def mock_keyring_get_password():
44+
with patch("keyring.get_password") as mock_get_password:
45+
yield mock_get_password
46+
47+
48+
@pytest.fixture
49+
def mock_keyring_set_password():
50+
with patch("keyring.set_password") as mock_set_password:
51+
yield mock_set_password
52+
53+
54+
def test_create_exchange_success(mock_print):
55+
with patch("goose.utils._create_exchange.build_exchange", return_value=TEST_EXCHANGE):
56+
assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE
57+
58+
59+
def test_create_exchange_fail_with_invalid_choice_error(mock_print, mock_sys_exit):
60+
expected_error = InvalidChoiceError(
61+
attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"]
62+
)
63+
with patch("goose.utils._create_exchange.build_exchange", side_effect=expected_error):
64+
create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER)
65+
66+
assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0]
67+
mock_sys_exit.assert_called_once_with(1)
68+
69+
70+
class TestWhenProviderEnvVarNotFound:
71+
API_KEY_ENV_VAR = "OPENAI_API_KEY"
72+
API_KEY_ENV_VALUE = "api_key_value"
73+
PROVIDER_NAME = "openai"
74+
SERVICE_NAME = "goose"
75+
EXPECTED_ERROR = MissingProviderEnvVariableError(env_variable=API_KEY_ENV_VAR, provider=PROVIDER_NAME)
76+
77+
def test_create_exchange_get_api_key_from_keychain(
78+
self, mock_print, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password
79+
):
80+
self._clean_env()
81+
with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]):
82+
mock_keyring_get_password.return_value = self.API_KEY_ENV_VALUE
83+
84+
assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE
85+
86+
assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE
87+
mock_keyring_get_password.assert_called_once_with(self.SERVICE_NAME, self.API_KEY_ENV_VAR)
88+
mock_print.assert_called_once_with(
89+
f"Using {self.API_KEY_ENV_VAR} value for {self.PROVIDER_NAME} from your keychain"
90+
)
91+
mock_sys_exit.assert_not_called()
92+
mock_keyring_set_password.assert_not_called()
93+
94+
def test_create_exchange_ask_api_key_and_user_set_in_keychain(
95+
self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password, mock_print
96+
):
97+
self._clean_env()
98+
with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]):
99+
mock_keyring_get_password.return_value = None
100+
mock_prompt.return_value = self.API_KEY_ENV_VALUE
101+
mock_confirm.return_value = True
102+
103+
assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE
104+
105+
assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE
106+
mock_keyring_set_password.assert_called_once_with(
107+
self.SERVICE_NAME, self.API_KEY_ENV_VAR, self.API_KEY_ENV_VALUE
108+
)
109+
mock_confirm.assert_called_once_with(
110+
f"Would you like to save the {self.API_KEY_ENV_VAR} value to your keychain?"
111+
)
112+
mock_print.assert_called_once_with(
113+
f"Saved {self.API_KEY_ENV_VAR} to your key_chain. "
114+
+ f"service_name: goose, user_name: {self.API_KEY_ENV_VAR}"
115+
)
116+
mock_sys_exit.assert_not_called()
117+
118+
def test_create_exchange_ask_api_key_and_user_not_set_in_keychain(
119+
self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password
120+
):
121+
self._clean_env()
122+
with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]):
123+
mock_keyring_get_password.return_value = None
124+
mock_prompt.return_value = self.API_KEY_ENV_VALUE
125+
mock_confirm.return_value = False
126+
127+
assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE
128+
129+
assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE
130+
mock_keyring_set_password.assert_not_called()
131+
mock_sys_exit.assert_not_called()
132+
133+
def test_create_exchange_fails_when_user_not_provide_api_key(
134+
self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_print
135+
):
136+
self._clean_env()
137+
with patch("goose.utils._create_exchange.build_exchange", side_effect=self.EXPECTED_ERROR):
138+
mock_keyring_get_password.return_value = None
139+
mock_prompt.return_value = None
140+
mock_confirm.return_value = False
141+
142+
create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER)
143+
144+
assert (
145+
"Please set the required environment variable to continue."
146+
in mock_print.call_args_list[0][0][0].renderable
147+
)
148+
mock_sys_exit.assert_called_once_with(1)
149+
150+
def _clean_env(self):
151+
os.environ.pop(self.API_KEY_ENV_VAR, None)

0 commit comments

Comments
 (0)