Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ The project uses **Nx** for build orchestration and task management

## Testing Guidelines

See @docs/testing/*.md for detailed patterns.

- Frameworks:
- Vitest (unit/component, happy-dom)
- Playwright (E2E)
Expand Down
1 change: 1 addition & 0 deletions src/components/honeyToast/HoneyToast.stories.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ function createMockJob(overrides: Partial<AssetDownload> = {}): AssetDownload {
bytesDownloaded: 0,
progress: 0,
status: 'created',
lastUpdate: Date.now(),
...overrides
}
}
Expand Down
1 change: 1 addition & 0 deletions src/components/toast/ProgressToastItem.stories.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ function createMockJob(overrides: Partial<AssetDownload> = {}): AssetDownload {
bytesDownloaded: 0,
progress: 0,
status: 'created',
lastUpdate: Date.now(),
...overrides
}
}
Expand Down
70 changes: 70 additions & 0 deletions src/platform/tasks/services/taskService.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/**
* Task Service for polling background task status.
*
* CAVEAT: The `payload` and `result` schemas below are specific to
* `task:download_file` tasks. Other task types may have different
* payload/result structures. We are not generalizing this until
* additional use cases arise.
*/
import { z } from 'zod'
import { fromZodError } from 'zod-validation-error'

import { api } from '@/scripts/api'

const TASKS_ENDPOINT = '/tasks'

const zTaskStatus = z.enum(['created', 'running', 'completed', 'failed'])

const zDownloadFileResult = z.object({
success: z.boolean(),
file_path: z.string().optional(),
bytes_downloaded: z.number().optional(),
content_type: z.string().optional(),
hash: z.string().optional(),
filename: z.string().optional(),
asset_id: z.string().optional(),
metadata: z.record(z.unknown()).optional(),
error: z.string().optional()
})

const zTaskResponse = z.object({
id: z.string().uuid(),
idempotency_key: z.string(),
task_name: z.string(),
payload: z.record(z.unknown()),
status: zTaskStatus,
result: zDownloadFileResult.optional(),
error_message: z.string().optional(),
create_time: z.string().datetime(),
update_time: z.string().datetime(),
started_at: z.string().datetime().optional(),
completed_at: z.string().datetime().optional()
})

export type TaskResponse = z.infer<typeof zTaskResponse>

function createTaskService() {
async function getTask(taskId: string): Promise<TaskResponse> {
const res = await api.fetchApi(`${TASKS_ENDPOINT}/${taskId}`)

if (!res.ok) {
if (res.status === 404) {
throw new Error(`Task not found: ${taskId}`)
}
throw new Error(`Failed to get task ${taskId}: ${res.status}`)
}

const data = await res.json()
const result = zTaskResponse.safeParse(data)

if (!result.success) {
throw new Error(fromZodError(result.error).message)
}

return result.data
}

return { getTask }
}

export const taskService = createTaskService()
2 changes: 1 addition & 1 deletion src/schemas/apiSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ const zFeatureFlagsWsMessage = z.record(z.string(), z.any())

const zAssetDownloadWsMessage = z.object({
task_id: z.string(),
asset_id: z.string(),
asset_name: z.string(),
bytes_total: z.number(),
bytes_downloaded: z.number(),
progress: z.number(),
status: z.enum(['created', 'running', 'completed', 'failed']),
asset_id: z.string().optional(),
error: z.string().optional()
})

Expand Down
225 changes: 225 additions & 0 deletions src/stores/assetDownloadStore.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

import type { TaskResponse } from '@/platform/tasks/services/taskService'
import { taskService } from '@/platform/tasks/services/taskService'
import type { AssetDownloadWsMessage } from '@/schemas/apiSchema'
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'

type DownloadEventHandler = (e: CustomEvent<AssetDownloadWsMessage>) => void

const eventHandler = vi.hoisted(() => {
const state: { current: DownloadEventHandler | null } = { current: null }
return state
})

vi.mock('@/scripts/api', () => ({
api: {
addEventListener: vi.fn((_event: string, handler: DownloadEventHandler) => {
eventHandler.current = handler
}),
removeEventListener: vi.fn()
}
}))

vi.mock('@/platform/tasks/services/taskService', () => ({
taskService: {
getTask: vi.fn()
}
}))

function createDownloadMessage(
overrides: Partial<AssetDownloadWsMessage> = {}
): AssetDownloadWsMessage {
return {
task_id: 'task-123',
asset_id: 'asset-456',
asset_name: 'model.safetensors',
bytes_total: 1000,
bytes_downloaded: 500,
progress: 50,
status: 'running',
...overrides
}
}

function dispatch(msg: AssetDownloadWsMessage) {
if (!eventHandler.current) {
throw new Error(
'Event handler not registered. Call useAssetDownloadStore() first.'
)
}
eventHandler.current(new CustomEvent('asset_download', { detail: msg }))
}

describe('useAssetDownloadStore', () => {
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
vi.useFakeTimers({ shouldAdvanceTime: false })
vi.resetAllMocks()
eventHandler.current = null
})

afterEach(() => {
vi.useRealTimers()
})

describe('handleAssetDownload', () => {
it('tracks running downloads', () => {
const store = useAssetDownloadStore()

dispatch(createDownloadMessage())

expect(store.activeDownloads).toHaveLength(1)
expect(store.activeDownloads[0].taskId).toBe('task-123')
expect(store.activeDownloads[0].progress).toBe(50)
})

it('moves download to finished when completed', () => {
const store = useAssetDownloadStore()

dispatch(createDownloadMessage({ status: 'running' }))
expect(store.activeDownloads).toHaveLength(1)

dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads).toHaveLength(1)
expect(store.finishedDownloads[0].status).toBe('completed')
})

it('moves download to finished when failed', () => {
const store = useAssetDownloadStore()

dispatch(createDownloadMessage({ status: 'running' }))
dispatch(
createDownloadMessage({ status: 'failed', error: 'Network error' })
)

expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads).toHaveLength(1)
expect(store.finishedDownloads[0].status).toBe('failed')
expect(store.finishedDownloads[0].error).toBe('Network error')
})

it('ignores duplicate terminal state messages', () => {
const store = useAssetDownloadStore()

dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

expect(store.finishedDownloads).toHaveLength(1)
})
})

describe('trackDownload', () => {
it('associates task with model type for completion tracking', () => {
const store = useAssetDownloadStore()

store.trackDownload('task-123', 'checkpoints')
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

expect(store.completedDownloads).toHaveLength(1)
expect(store.completedDownloads[0]).toMatchObject({
taskId: 'task-123',
modelType: 'checkpoints'
})
})
})

describe('stale download polling', () => {
function createTaskResponse(
overrides: Partial<TaskResponse> = {}
): TaskResponse {
return {
id: 'task-123',
idempotency_key: 'key-123',
task_name: 'task:download_file',
payload: {},
status: 'completed',
create_time: new Date().toISOString(),
update_time: new Date().toISOString(),
result: {
success: true,
asset_id: 'asset-456',
filename: 'model.safetensors',
bytes_downloaded: 1000
},
...overrides
}
}

it('polls and completes stale downloads', async () => {
const store = useAssetDownloadStore()

vi.mocked(taskService.getTask).mockResolvedValue(createTaskResponse())

dispatch(createDownloadMessage({ status: 'running' }))
expect(store.activeDownloads).toHaveLength(1)

await vi.advanceTimersByTimeAsync(45_000)

expect(taskService.getTask).toHaveBeenCalledWith('task-123')
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads[0].status).toBe('completed')
})

it('polls and marks failed downloads', async () => {
const store = useAssetDownloadStore()

vi.mocked(taskService.getTask).mockResolvedValue(
createTaskResponse({
status: 'failed',
error_message: 'Download failed',
result: { success: false, error: 'Network error' }
})
)

dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(45_000)

expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads[0].status).toBe('failed')
expect(store.finishedDownloads[0].error).toBe('Download failed')
})

it('does not complete if task still running', async () => {
const store = useAssetDownloadStore()

vi.mocked(taskService.getTask).mockResolvedValue(
createTaskResponse({ status: 'running', result: undefined })
)

dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(45_000)

expect(taskService.getTask).toHaveBeenCalled()
expect(store.activeDownloads).toHaveLength(1)
})

it('continues tracking on polling error', async () => {
const store = useAssetDownloadStore()

vi.mocked(taskService.getTask).mockRejectedValue(new Error('Not found'))
dispatch(createDownloadMessage({ status: 'running' }))

await vi.advanceTimersByTimeAsync(45_000)

expect(store.activeDownloads).toHaveLength(1)
})
})

describe('clearFinishedDownloads', () => {
it('removes all finished downloads', () => {
const store = useAssetDownloadStore()

dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.finishedDownloads).toHaveLength(1)

store.clearFinishedDownloads()

expect(store.finishedDownloads).toHaveLength(0)
})
})
})
Loading
Loading