Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions crates/goose-cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1443,7 +1443,7 @@ async fn handle_term_subcommand(command: TermCommand) -> Result<()> {
async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()> {
use goose::providers::local_inference::hf_models;
use goose::providers::local_inference::local_model_registry::{
display_name_from_repo, get_registry, model_id_from_repo, LocalModelEntry,
get_registry, model_id_from_repo, LocalModelEntry,
};

match command {
Expand Down Expand Up @@ -1482,13 +1482,12 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()>
println!("Resolving {}...", spec);
let (repo_id, file) = hf_models::resolve_model_spec(&spec).await?;
let model_id = model_id_from_repo(&repo_id, &file.quantization);
let display_name = display_name_from_repo(&repo_id, &file.quantization);
let local_path =
goose::config::paths::Paths::in_data_dir("models").join(&file.filename);

println!(
"Downloading {} ({})...",
display_name,
model_id,
if file.size_bytes > 0 {
format!(
"{:.1}GB",
Expand All @@ -1502,7 +1501,6 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()>
// Register
let entry = LocalModelEntry {
id: model_id.clone(),
display_name: display_name.clone(),
repo_id: repo_id.clone(),
filename: file.filename.clone(),
quantization: file.quantization.clone(),
Expand Down Expand Up @@ -1545,7 +1543,7 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()>
std::io::stdout().flush().ok();
}
goose::download_manager::DownloadStatus::Completed => {
println!("\nDownloaded: {} (id: {})", display_name, model_id);
println!("\nDownloaded: {}", model_id);
break;
}
goose::download_manager::DownloadStatus::Failed => {
Expand All @@ -1572,13 +1570,12 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()>
return Ok(());
}

println!("{:<40} {:<20} {:<10} Downloaded", "ID", "Name", "Quant");
println!("{}", "-".repeat(80));
println!("{:<50} {:<10} Downloaded", "ID", "Quant");
println!("{}", "-".repeat(70));
for m in models {
println!(
"{:<40} {:<20} {:<10} {}",
"{:<50} {:<10} {}",
m.id,
m.display_name,
m.quantization,
if m.is_downloaded() { "✓" } else { "✗" }
);
Expand Down
11 changes: 3 additions & 8 deletions crates/goose-server/src/routes/local_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ use goose::providers::local_inference::{
available_inference_memory_bytes,
hf_models::{resolve_model_spec, HfGgufFile},
local_model_registry::{
display_name_from_repo, get_registry, is_featured_model, model_id_from_repo,
LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus, ModelSettings,
FEATURED_MODELS,
get_registry, is_featured_model, model_id_from_repo, LocalModelEntry,
ModelDownloadStatus as RegistryDownloadStatus, ModelSettings, FEATURED_MODELS,
},
recommend_local_model,
};
Expand All @@ -40,7 +39,6 @@ pub enum ModelDownloadStatus {
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct LocalModelResponse {
pub id: String,
pub display_name: String,
pub repo_id: String,
pub filename: String,
pub quantization: String,
Expand Down Expand Up @@ -94,7 +92,6 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {

entries_to_add.push(LocalModelEntry {
id: model_id,
display_name: display_name_from_repo(&repo_id, &quantization),
repo_id,
filename: hf_file.filename,
quantization,
Expand Down Expand Up @@ -158,7 +155,6 @@ pub async fn list_local_models(

models.push(LocalModelResponse {
id: entry.id.clone(),
display_name: entry.display_name.clone(),
repo_id: entry.repo_id.clone(),
filename: entry.filename.clone(),
quantization: entry.quantization.clone(),
Expand All @@ -175,7 +171,7 @@ pub async fn list_local_models(
match (b_downloaded, a_downloaded) {
(true, false) => std::cmp::Ordering::Greater,
(false, true) => std::cmp::Ordering::Less,
_ => a.display_name.cmp(&b.display_name),
_ => a.id.cmp(&b.id),
}
});

Expand Down Expand Up @@ -272,7 +268,6 @@ pub async fn download_hf_model(

let entry = LocalModelEntry {
id: model_id.clone(),
display_name: display_name_from_repo(&repo_id, &quantization),
repo_id,
filename: hf_file.filename,
quantization,
Expand Down
12 changes: 0 additions & 12 deletions crates/goose/src/providers/local_inference/local_model_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ pub fn get_registry() -> &'static Mutex<LocalModelRegistry> {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalModelEntry {
pub id: String,
pub display_name: String,
pub repo_id: String,
pub filename: String,
pub quantization: String,
Expand Down Expand Up @@ -306,14 +305,3 @@ impl LocalModelRegistry {
pub fn model_id_from_repo(repo_id: &str, quantization: &str) -> String {
format!("{}:{}", repo_id, quantization)
}

/// Generate a display name from repo_id and quantization.
pub fn display_name_from_repo(repo_id: &str, quantization: &str) -> String {
let model_name = repo_id
.split('/')
.next_back()
.unwrap_or(repo_id)
.trim_end_matches("-GGUF")
.trim_end_matches("-gguf");
format!("{} ({})", model_name, quantization)
}
4 changes: 0 additions & 4 deletions ui/desktop/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -5185,7 +5185,6 @@
"type": "object",
"required": [
"id",
"display_name",
"repo_id",
"filename",
"quantization",
Expand All @@ -5195,9 +5194,6 @@
"settings"
],
"properties": {
"display_name": {
"type": "string"
},
"filename": {
"type": "string"
},
Expand Down
1 change: 0 additions & 1 deletion ui/desktop/src/api/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,6 @@ export type LoadedProvider = {
};

export type LocalModelResponse = {
display_name: string;
filename: string;
id: string;
quantization: string;
Expand Down
10 changes: 5 additions & 5 deletions ui/desktop/src/components/LocalModelSetup.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ export function LocalModelSetup({ onSuccess, onCancel }: LocalModelSetupProps) {
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2 flex-wrap">
<span className="font-medium text-text-default text-sm sm:text-base">
{recommended.display_name}
{recommended.id}
</span>
{recommended.status.state === 'Downloaded' && (
<span className="text-xs bg-green-600 text-white px-2 py-0.5 rounded-full">
Expand Down Expand Up @@ -294,7 +294,7 @@ export function LocalModelSetup({ onSuccess, onCancel }: LocalModelSetupProps) {
/>
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2 flex-wrap">
<span className="font-medium text-text-default text-sm">{model.display_name}</span>
<span className="font-medium text-text-default text-sm">{model.id}</span>
<span className="text-xs text-text-muted">{formatSize(model.size_bytes)}</span>
{model.status.state === 'Downloaded' && (
<span className="text-xs bg-green-600 text-white px-2 py-0.5 rounded-full">
Expand All @@ -318,9 +318,9 @@ export function LocalModelSetup({ onSuccess, onCancel }: LocalModelSetupProps) {
className="w-full px-6 py-3 bg-background-muted text-text-default rounded-lg transition-colors font-medium disabled:opacity-40 disabled:cursor-not-allowed hover:bg-background-muted/80"
>
{selectedModel?.status.state === 'Downloaded'
? `Use ${selectedModel.display_name}`
? `Use ${selectedModel.id}`
: selectedModel
? `Download ${selectedModel.display_name} (${formatSize(selectedModel.size_bytes)})`
? `Download ${selectedModel.id} (${formatSize(selectedModel.size_bytes)})`
: 'Select a model'}
</button>

Expand All @@ -338,7 +338,7 @@ export function LocalModelSetup({ onSuccess, onCancel }: LocalModelSetupProps) {
<div className="space-y-6">
<div className="border border-border-subtle rounded-xl p-5 sm:p-6 bg-background-default">
<p className="font-medium text-text-default text-sm sm:text-base mb-4">
Downloading {selectedModel.display_name}
Downloading {selectedModel.id}
</p>

{downloadProgress ? (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,6 @@ export const LocalInferenceSettings = () => {
const downloadSectionRef = useRef<HTMLDivElement>(null);
const selectedModelId = currentProvider === 'local' ? currentModel : null;

const getDisplayName = useCallback(
(modelId: string): string => {
const model = models.find((m) => m.id === modelId);
return model?.display_name || modelId;
},
[models]
);

const loadModels = useCallback(async () => {
try {
const response = await listLocalModels();
Expand Down Expand Up @@ -195,15 +187,14 @@ export const LocalInferenceSettings = () => {
<div className="space-y-2">
{Array.from(downloads.entries()).map(([modelId, progress]) => {
if (progress.status === 'completed') return null;
const displayName = getDisplayName(modelId);
return (
<div
key={modelId}
className="border rounded-lg p-3 border-border-subtle bg-background-default"
>
<div className="flex items-center justify-between mb-2">
<span className="text-sm font-medium text-text-default truncate">
{displayName}
{modelId}
</span>
{progress.status === 'downloading' && (
<Button
Expand Down Expand Up @@ -283,7 +274,7 @@ export const LocalInferenceSettings = () => {
className="cursor-pointer"
/>
<span className="text-sm font-medium text-text-default">
{model.display_name}
{model.id}
</span>
<span className="text-xs text-text-muted">
{formatBytes(model.size_bytes)}
Expand Down Expand Up @@ -334,7 +325,7 @@ export const LocalInferenceSettings = () => {
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2 flex-wrap">
<h4 className="text-sm font-medium text-text-default">
{model.display_name}
{model.id}
</h4>
<span className="text-xs text-text-muted">
{formatBytes(model.size_bytes)}
Expand Down Expand Up @@ -400,7 +391,7 @@ export const LocalInferenceSettings = () => {
<DialogContent className="max-h-[80vh] overflow-y-auto sm:max-w-xl">
<DialogHeader>
<DialogTitle>Model Settings</DialogTitle>
<p className="text-sm text-text-muted">{getDisplayName(settingsOpenFor || '')}</p>
<p className="text-sm text-text-muted">{settingsOpenFor || ''}</p>
</DialogHeader>
{settingsOpenFor && <ModelSettingsPanel modelId={settingsOpenFor} />}
</DialogContent>
Expand Down
Loading