Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/platform/assets/components/AssetBrowserModal.vue
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
:assets="filteredAssets"
:loading="isLoading"
@asset-select="handleAssetSelectAndEmit"
@asset-deleted="refreshAssets"
/>
</template>
</BaseModalLayout>
Expand Down
7 changes: 3 additions & 4 deletions src/platform/assets/components/AssetCard.vue
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
<template>
<div
v-if="!deletedLocal"
data-component-id="AssetCard"
:data-asset-id="asset.id"
:aria-labelledby="titleId"
Expand Down Expand Up @@ -139,8 +138,9 @@ const { asset, interactive } = defineProps<{
interactive?: boolean
}>()

defineEmits<{
const emit = defineEmits<{
select: [asset: AssetDisplayItem]
deleted: [asset: AssetDisplayItem]
}>()

const { t } = useI18n()
Expand All @@ -158,7 +158,6 @@ const descId = useId()

const isEditing = ref(false)
const newNameRef = ref<string>()
const deletedLocal = ref(false)

const displayName = computed(() => newNameRef.value ?? asset.name)

Expand Down Expand Up @@ -211,7 +210,7 @@ function confirmDeletion() {
})
// Give a second for the completion message
await new Promise((resolve) => setTimeout(resolve, 1_000))
deletedLocal.value = true
emit('deleted', asset)
} catch (err: unknown) {
console.error(err)
promptText.value = t('assetBrowser.deletion.failed', {
Expand Down
2 changes: 2 additions & 0 deletions src/platform/assets/components/AssetGrid.vue
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
:asset="item"
:interactive="true"
@select="$emit('assetSelect', $event)"
@deleted="$emit('assetDeleted', $event)"
/>
</template>
</VirtualGrid>
Expand All @@ -56,6 +57,7 @@ const { assets } = defineProps<{

defineEmits<{
assetSelect: [asset: AssetDisplayItem]
assetDeleted: [asset: AssetDisplayItem]
}>()

const assetsWithKey = computed(() =>
Expand Down
3 changes: 2 additions & 1 deletion src/platform/assets/composables/useUploadModelWizard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ export function useUploadModelWizard(modelTypes: Ref<ModelTypeOption[]>) {
if (selectedModelType.value) {
assetDownloadStore.trackDownload(
result.task.task_id,
selectedModelType.value
selectedModelType.value,
filename
)
}
uploadStatus.value = 'processing'
Expand Down
20 changes: 17 additions & 3 deletions src/stores/assetDownloadStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,29 @@ describe('useAssetDownloadStore', () => {
it('associates task with model type for completion tracking', () => {
const store = useAssetDownloadStore()

store.trackDownload('task-123', 'checkpoints')
store.trackDownload('task-123', 'checkpoints', 'model.safetensors')
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

expect(store.completedDownloads).toHaveLength(1)
expect(store.completedDownloads[0]).toMatchObject({
expect(store.lastCompletedDownload).toMatchObject({
taskId: 'task-123',
modelType: 'checkpoints'
})
})

it('handles out-of-order messages where completed arrives before progress', () => {
const store = useAssetDownloadStore()

store.trackDownload('task-123', 'checkpoints', 'model.safetensors')

dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

dispatch(createDownloadMessage({ status: 'running', progress: 50 }))

expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads).toHaveLength(1)
expect(store.finishedDownloads[0].status).toBe('completed')
expect(store.lastCompletedDownload?.modelType).toBe('checkpoints')
})
})

describe('stale download polling', () => {
Expand Down
54 changes: 34 additions & 20 deletions src/stores/assetDownloadStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,37 @@ export interface AssetDownload {
lastUpdate: number
assetId?: string
error?: string
modelType?: string
}

interface CompletedDownload {
taskId: string
modelType: string
timestamp: number
}

const MAX_COMPLETED_DOWNLOADS = 10
const STALE_THRESHOLD_MS = 10_000
const POLL_INTERVAL_MS = 10_000

function generateDownloadTrackingPlaceholder(
taskId: string,
modelType: string,
assetName: string
): AssetDownload {
return {
taskId,
modelType,
assetName,
bytesTotal: 0,
bytesDownloaded: 0,
progress: 0,
status: 'created',
lastUpdate: Date.now()
}
}

export const useAssetDownloadStore = defineStore('assetDownload', () => {
const downloads = ref<Map<string, AssetDownload>>(new Map())
const pendingModelTypes = new Map<string, string>()
const completedDownloads = ref<CompletedDownload[]>([])
const lastCompletedDownload = ref<CompletedDownload | null>(null)

const downloadList = computed(() => Array.from(downloads.value.values()))
const activeDownloads = computed(() =>
Expand All @@ -47,8 +62,13 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
const hasActiveDownloads = computed(() => activeDownloads.value.length > 0)
const hasDownloads = computed(() => downloads.value.size > 0)

function trackDownload(taskId: string, modelType: string) {
pendingModelTypes.set(taskId, modelType)
function trackDownload(taskId: string, modelType: string, assetName: string) {
if (downloads.value.has(taskId)) return

downloads.value.set(
taskId,
generateDownloadTrackingPlaceholder(taskId, modelType, assetName)
)
}

function handleAssetDownload(e: CustomEvent<AssetDownloadWsMessage>) {
Expand All @@ -69,24 +89,18 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
progress: data.progress,
status: data.status,
error: data.error,
lastUpdate: Date.now()
lastUpdate: Date.now(),
modelType: existing?.modelType
}

downloads.value.set(data.task_id, download)

if (data.status === 'completed') {
const modelType = pendingModelTypes.get(data.task_id)
if (modelType) {
const updated = [
...completedDownloads.value,
{ taskId: data.task_id, modelType, timestamp: Date.now() }
]
if (updated.length > MAX_COMPLETED_DOWNLOADS) updated.shift()
completedDownloads.value = updated
pendingModelTypes.delete(data.task_id)
if (data.status === 'completed' && download.modelType) {
lastCompletedDownload.value = {
taskId: data.task_id,
modelType: download.modelType,
timestamp: Date.now()
}
} else if (data.status === 'failed') {
pendingModelTypes.delete(data.task_id)
}
}

Expand Down Expand Up @@ -157,7 +171,7 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
hasActiveDownloads,
hasDownloads,
downloadList,
completedDownloads,
lastCompletedDownload,
trackDownload,
clearFinishedDownloads
}
Expand Down
30 changes: 19 additions & 11 deletions src/stores/assetsStore.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { useAsyncState } from '@vueuse/core'
import { useAsyncState, whenever } from '@vueuse/core'
import { isEqual } from 'es-toolkit'
import { defineStore } from 'pinia'
import { computed, shallowReactive, ref, watch } from 'vue'
import { computed, shallowReactive, ref } from 'vue'
import {
mapInputFileToAssetItem,
mapTaskOutputToAssetItem
Expand Down Expand Up @@ -376,24 +376,32 @@ export const useAssetsStore = defineStore('assets', () => {
} = getModelState()

// Watch for completed downloads and refresh model caches
watch(
() => assetDownloadStore.completedDownloads.at(-1),
whenever(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat!

() => assetDownloadStore.lastCompletedDownload,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense

async (latestDownload) => {
if (!latestDownload) return

const { modelType } = latestDownload

const providers = modelToNodeStore
.getAllNodeProviders(modelType)
.filter((provider) => provider.nodeDef?.name)
const results = await Promise.allSettled(
providers.map((provider) =>
updateModelsForNodeType(provider.nodeDef.name).then(
() => provider.nodeDef.name
)

const nodeTypeUpdates = providers.map((provider) =>
updateModelsForNodeType(provider.nodeDef.name).then(
() => provider.nodeDef.name
)
)

// Also update by tag in case modal was opened with assetType
const tagUpdates = [
updateModelsForTag(modelType),
updateModelsForTag('models')
]

const results = await Promise.allSettled([
...nodeTypeUpdates,
...tagUpdates
])

for (const result of results) {
if (result.status === 'rejected') {
console.error(
Expand Down
Loading