Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions crates/goose-server/src/routes/config_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ pub struct UpdateCustomProviderRequest {
pub supports_streaming: Option<bool>,
}

#[derive(Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct MaskedSecret {
pub masked_value: String,
}

#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub enum ConfigValueResponse {
Value(Value),
MaskedValue(MaskedSecret),
}

#[utoipa::path(
post,
path = "/config/upsert",
Expand Down Expand Up @@ -134,6 +147,22 @@ pub async fn remove_config(Json(query): Json<ConfigKeyQuery>) -> Result<Json<Str
}
}

const SECRET_MASK_SHOW_LEN: usize = 8;

fn mask_secret(secret: Value) -> String {
let as_string = match secret {
Value::String(s) => s,
_ => serde_json::to_string(&secret).unwrap_or_else(|_| secret.to_string()),
};

let chars: Vec<_> = as_string.chars().collect();
let show_len = std::cmp::min(chars.len() / 2, SECRET_MASK_SHOW_LEN);
let visible: String = chars.iter().take(show_len).collect();
let mask = "*".repeat(chars.len() - show_len);

format!("{}{}", visible, mask)
}

#[utoipa::path(
post,
path = "/config/read",
Expand All @@ -143,31 +172,29 @@ pub async fn remove_config(Json(query): Json<ConfigKeyQuery>) -> Result<Json<Str
(status = 500, description = "Unable to get the configuration value"),
)
)]
pub async fn read_config(Json(query): Json<ConfigKeyQuery>) -> Result<Json<Value>, StatusCode> {
pub async fn read_config(
Json(query): Json<ConfigKeyQuery>,
) -> Result<Json<ConfigValueResponse>, StatusCode> {
if query.key == "model-limits" {
let limits = ModelConfig::get_all_model_limits();
return Ok(Json(
return Ok(Json(ConfigValueResponse::Value(
serde_json::to_value(limits).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
));
)));
}

let config = Config::global();

let response_value = match config.get(&query.key, query.is_secret) {
Ok(value) => {
if query.is_secret {
Value::Bool(true)
} else {
value
}
}
Err(ConfigError::NotFound(_)) => {
if query.is_secret {
Value::Bool(false)
ConfigValueResponse::MaskedValue(MaskedSecret {
masked_value: mask_secret(value),
})
} else {
Value::Null
ConfigValueResponse::Value(value)
}
}
Err(ConfigError::NotFound(_)) => ConfigValueResponse::Value(Value::Null),
Err(_) => {
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
Expand Down Expand Up @@ -752,10 +779,12 @@ mod tests {
.await;

assert!(result.is_ok());
let response = result.unwrap();
let response = match result.unwrap().0 {
ConfigValueResponse::Value(value) => value,
ConfigValueResponse::MaskedValue(_) => panic!("unexpected secret"),
};

let limits: Vec<goose::model::ModelLimitConfig> =
serde_json::from_value(response.0).unwrap();
let limits: Vec<goose::model::ModelLimitConfig> = serde_json::from_value(response).unwrap();
assert!(!limits.is_empty());

let gpt4_limit = limits.iter().find(|l| l.pattern == "gpt-4o");
Expand Down
91 changes: 34 additions & 57 deletions ui/desktop/src/components/settings/providers/ProviderGrid.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import React, { memo, useMemo, useCallback, useState } from 'react';
import { ProviderCard } from './subcomponents/ProviderCard';
import CardContainer from './subcomponents/CardContainer';
import { ProviderModalProvider, useProviderModal } from './modal/ProviderModalProvider';
import ProviderConfigurationModal from './modal/ProviderConfiguationModal';
import {
DeclarativeProviderConfig,
Expand Down Expand Up @@ -47,7 +46,7 @@ const CustomProviderCard = memo(function CustomProviderCard({ onClick }: { onCli
);
});

const ProviderCards = memo(function ProviderCards({
function ProviderCards({
providers,
isOnboarding,
refreshProviders,
Expand All @@ -58,14 +57,19 @@ const ProviderCards = memo(function ProviderCards({
refreshProviders?: () => void;
onProviderLaunch: (provider: ProviderDetails) => void;
}) {
const { openModal } = useProviderModal();
const [configuringProvider, setConfiguringProvider] = useState<ProviderDetails | null>(null);
const [showCustomProviderModal, setShowCustomProviderModal] = useState(false);
const [editingProvider, setEditingProvider] = useState<{
id: string;
config: DeclarativeProviderConfig;
isEditable: boolean;
} | null>(null);

const openModal = useCallback(
(provider: ProviderDetails) => setConfiguringProvider(provider),
[]
);

const configureProviderViaModal = useCallback(
async (provider: ProviderDetails) => {
if (provider.provider_type === 'Custom' || provider.provider_type === 'Declarative') {
Expand All @@ -81,22 +85,10 @@ const ProviderCards = memo(function ProviderCards({
setShowCustomProviderModal(true);
}
} else {
openModal(provider, {
onSubmit: () => {
if (refreshProviders) {
refreshProviders();
}
},
onDelete: (_values: unknown) => {
if (refreshProviders) {
refreshProviders();
}
},
formProps: {},
});
openModal(provider);
}
},
[openModal, refreshProviders]
[openModal]
);

const handleUpdateCustomProvider = useCallback(
Expand All @@ -123,20 +115,12 @@ const ProviderCards = memo(function ProviderCards({
setEditingProvider(null);
}, []);

const deleteProviderConfigViaModal = useCallback(
(provider: ProviderDetails) => {
openModal(provider, {
onDelete: (_values: unknown) => {
// Only refresh if the function is provided
if (refreshProviders) {
refreshProviders();
}
},
formProps: {},
});
},
[openModal, refreshProviders]
);
const onCloseProviderConfig = useCallback(() => {
setConfiguringProvider(null);
if (refreshProviders) {
refreshProviders();
}
}, [refreshProviders]);

const handleCreateCustomProvider = useCallback(
async (data: UpdateCustomProviderRequest) => {
Expand All @@ -160,7 +144,6 @@ const ProviderCards = memo(function ProviderCards({
key={provider.name}
provider={provider}
onConfigure={() => configureProviderViaModal(provider)}
onDelete={() => deleteProviderConfigViaModal(provider)}
onLaunch={() => onProviderLaunch(provider)}
isOnboarding={isOnboarding}
/>
Expand All @@ -171,13 +154,7 @@ const ProviderCards = memo(function ProviderCards({
);

return cards;
}, [
providers,
isOnboarding,
configureProviderViaModal,
deleteProviderConfigViaModal,
onProviderLaunch,
]);
}, [providers, isOnboarding, configureProviderViaModal, onProviderLaunch]);

const initialData = editingProvider && {
engine: editingProvider.config.engine.toLowerCase() + '_compatible',
Expand Down Expand Up @@ -206,11 +183,17 @@ const ProviderCards = memo(function ProviderCards({
/>
</DialogContent>
</Dialog>{' '}
{configuringProvider && (
<ProviderConfigurationModal
provider={configuringProvider}
onClose={onCloseProviderConfig}
/>
)}
</>
);
});
}

export default memo(function ProviderGrid({
export default function ProviderGrid({
providers,
isOnboarding,
refreshProviders,
Expand All @@ -221,20 +204,14 @@ export default memo(function ProviderGrid({
refreshProviders?: () => void;
onProviderLaunch?: (provider: ProviderDetails) => void;
}) {
// Memoize the modal provider and its children to avoid recreating on every render
const modalProviderContent = useMemo(
() => (
<ProviderModalProvider>
<ProviderCards
providers={providers}
isOnboarding={isOnboarding}
refreshProviders={refreshProviders}
onProviderLaunch={onProviderLaunch || (() => {})}
/>
<ProviderConfigurationModal />
</ProviderModalProvider>
),
[providers, isOnboarding, refreshProviders, onProviderLaunch]
return (
<GridLayout>
<ProviderCards
providers={providers}
isOnboarding={isOnboarding}
refreshProviders={refreshProviders}
onProviderLaunch={onProviderLaunch || (() => {})}
/>
</GridLayout>
);
return <GridLayout>{modalProviderContent}</GridLayout>;
});
}
Loading
Loading