Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions crates/goose/src/download_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ pub struct DownloadProgress {
pub eta_seconds: Option<u64>,
/// Error message if failed
pub error: Option<String>,
/// Whether the background download task has exited
#[serde(skip)]
pub task_exited: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq)]
Expand Down Expand Up @@ -108,8 +111,15 @@ impl DownloadManager {
.lock()
.map_err(|_| anyhow::anyhow!("Failed to acquire lock"))?;

if downloads.contains_key(&model_id) {
anyhow::bail!("Download already in progress");
if let Some(existing) = downloads.get(&model_id) {
if existing.status == DownloadStatus::Downloading {
anyhow::bail!("Download already in progress");
}
Comment thread
jh-block marked this conversation as resolved.
if existing.status == DownloadStatus::Cancelled && !existing.task_exited {
anyhow::bail!(
"Download is being cancelled; wait for it to finish before restarting"
);
}
}

downloads.insert(
Expand All @@ -123,6 +133,7 @@ impl DownloadManager {
speed_bps: None,
eta_seconds: None,
error: None,
task_exited: false,
},
);
}
Expand All @@ -148,6 +159,7 @@ impl DownloadManager {
if let Some(progress) = downloads.get_mut(&model_id_clone) {
progress.status = DownloadStatus::Completed;
progress.progress_percent = 100.0;
progress.task_exited = true;
}
}

Expand All @@ -162,8 +174,11 @@ impl DownloadManager {

if let Ok(mut downloads) = downloads.lock() {
if let Some(progress) = downloads.get_mut(&model_id_clone) {
progress.status = DownloadStatus::Failed;
if progress.status != DownloadStatus::Cancelled {
progress.status = DownloadStatus::Failed;
}
progress.error = Some(e.to_string());
progress.task_exited = true;
}
}
}
Expand All @@ -179,7 +194,10 @@ impl DownloadManager {
downloads: &DownloadMap,
model_id: &str,
) -> Result<(), anyhow::Error> {
let client = reqwest::Client::new();
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(30))
.read_timeout(std::time::Duration::from_secs(60))
.build()?;
let mut response = client.get(url).send().await?;

if !response.status().is_success() {
Expand Down Expand Up @@ -217,7 +235,7 @@ impl DownloadManager {

if should_cancel {
let _ = tokio::fs::remove_file(&partial_path).await;
return Ok(());
anyhow::bail!("Download cancelled");
}

file.write_all(&chunk).await?;
Expand Down Expand Up @@ -264,10 +282,10 @@ impl DownloadManager {
pub fn clear_completed(&self, model_id: &str) {
if let Ok(mut downloads) = self.downloads.lock() {
if let Some(progress) = downloads.get(model_id) {
if progress.status == DownloadStatus::Completed
let is_terminal = progress.status == DownloadStatus::Completed
|| progress.status == DownloadStatus::Failed
|| progress.status == DownloadStatus::Cancelled
{
|| progress.status == DownloadStatus::Cancelled;
if is_terminal && progress.task_exited {
downloads.remove(model_id);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,41 +30,33 @@ export const LocalInferenceSettings = () => {
const [settingsOpenFor, setSettingsOpenFor] = useState<string | null>(null);
const { currentModel, currentProvider, refreshCurrentModelAndProvider } = useModelAndProvider();
const downloadSectionRef = useRef<HTMLDivElement>(null);
const activePolls = useRef(new Set<string>());
const selectedModelId = currentProvider === 'local' ? currentModel : null;

const loadModels = useCallback(async () => {
const loadModels = useCallback(async (): Promise<LocalModelResponse[] | undefined> => {
try {
const response = await listLocalModels();
if (response.data) {
setModels(response.data);
response.data.forEach((model) => {
if (model.status.state === 'Downloading') {
pollDownloadProgress(model.id);
Comment thread
jh-block marked this conversation as resolved.
}
});

return response.data;
}
} catch (error) {
console.error('Failed to load models:', error);
}
}, []);

// Check for any in-progress downloads when models list changes
const detectActiveDownloads = useCallback(async () => {
for (const model of models) {
if (downloads.has(model.id)) continue;
// Check models that the API reports as downloading
if (model.status.state === 'Downloading') {
pollDownloadProgress(model.id);
}
}
return undefined;
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [models, downloads]);
}, []);

useEffect(() => {
loadModels();
}, [loadModels]);

useEffect(() => {
if (models.length > 0) {
detectActiveDownloads();
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [models]);
}, []);

const selectModel = async (modelId: string) => {
try {
Expand Down Expand Up @@ -97,6 +89,14 @@ export const LocalInferenceSettings = () => {
}, []);

const pollDownloadProgress = (modelId: string) => {
if (activePolls.current.has(modelId)) return;
activePolls.current.add(modelId);

const stopPolling = (interval: ReturnType<typeof setInterval>) => {
clearInterval(interval);
activePolls.current.delete(modelId);
};

const interval = setInterval(async () => {
try {
const response = await getLocalModelDownloadProgress({ path: { model_id: modelId } });
Expand All @@ -105,23 +105,28 @@ export const LocalInferenceSettings = () => {
setDownloads((prev) => new Map(prev).set(modelId, progress));

if (progress.status === 'completed') {
clearInterval(interval);
stopPolling(interval);
setDownloads((prev) => {
const next = new Map(prev);
next.delete(modelId);
return next;
});
await loadModels();
await selectModel(modelId);
} else if (progress.status === 'failed') {
clearInterval(interval);
} else if (progress.status === 'failed' || progress.status === 'cancelled') {
stopPolling(interval);
setDownloads((prev) => {
const next = new Map(prev);
next.delete(modelId);
return next;
});
await loadModels();
}
} else {
clearInterval(interval);
stopPolling(interval);
}
} catch {
clearInterval(interval);
stopPolling(interval);
}
}, 1000);
};
Expand All @@ -134,6 +139,7 @@ export const LocalInferenceSettings = () => {
next.delete(modelId);
return next;
});
await loadModels();
} catch (error) {
console.error('Failed to cancel download:', error);
}
Expand All @@ -143,7 +149,16 @@ export const LocalInferenceSettings = () => {
if (!window.confirm('Delete this model? You can re-download it later.')) return;
try {
await deleteLocalModel({ path: { model_id: modelId } });
await loadModels();
const updatedModels = await loadModels();

if (selectedModelId === modelId && updatedModels) {
const remainingDownloaded = updatedModels.filter(
(m) => m.id !== modelId && m.status.state === 'Downloaded'
);
if (remainingDownloaded.length > 0) {
selectModel(remainingDownloaded[0].id);
}
}
} catch (error) {
console.error('Failed to delete model:', error);
}
Expand Down
Loading