Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 160 additions & 23 deletions src/stores/assetsStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,41 @@ vi.mock('@/platform/distribution/types', () => ({
}
}))

// Mock modelToNodeStore with proper node providers and category lookups
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: () => ({
getAllNodeProviders: vi.fn((category: string) => {
const providers: Record<
string,
Array<{ nodeDef: { name: string }; key: string }>
> = {
checkpoints: [
{ nodeDef: { name: 'CheckpointLoaderSimple' }, key: 'ckpt_name' },
{ nodeDef: { name: 'ImageOnlyCheckpointLoader' }, key: 'ckpt_name' }
],
loras: [
{ nodeDef: { name: 'LoraLoader' }, key: 'lora_name' },
{ nodeDef: { name: 'LoraLoaderModelOnly' }, key: 'lora_name' }
],
vae: [{ nodeDef: { name: 'VAELoader' }, key: 'vae_name' }]
}
return providers[category] ?? []
}),
getCategoryForNodeType: vi.fn((nodeType: string) => {
const nodeToCategory: Record<string, string> = {
CheckpointLoaderSimple: 'checkpoints',
ImageOnlyCheckpointLoader: 'checkpoints',
LoraLoader: 'loras',
LoraLoaderModelOnly: 'loras',
VAELoader: 'vae'
}
return nodeToCategory[nodeType]
}),
getNodeProvider: vi.fn(),
registerDefaults: vi.fn()
})
}))

// Mock TaskItemImpl
vi.mock('@/stores/queueStore', () => ({
TaskItemImpl: class {
Expand Down Expand Up @@ -546,51 +581,70 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
expect(assets.map((a) => a.id)).toContain('new-asset')
})

it('should return cached array on subsequent getAssets calls', () => {
it('should return cached array on subsequent getAssets calls', async () => {
const store = useAssetsStore()
const nodeType = 'TestLoader'
const nodeType = 'CheckpointLoaderSimple'
const assets = [createMockAsset('cache-test-1')]

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(assets)
await store.updateModelsForNodeType(nodeType)

const firstCall = store.getAssets(nodeType)
const secondCall = store.getAssets(nodeType)

expect(secondCall).toBe(firstCall)
expect(firstCall).toHaveLength(1)
})
})

describe('concurrent request handling', () => {
it('should discard stale request when newer request starts', async () => {
it('should short-circuit concurrent calls to prevent duplicate work', async () => {
const store = useAssetsStore()
const nodeType = 'CheckpointLoaderSimple'
const firstBatch = Array.from({ length: 5 }, (_, i) =>
createMockAsset(`first-${i}`)
)
const secondBatch = Array.from({ length: 10 }, (_, i) =>
createMockAsset(`second-${i}`)
)

let resolveFirst: (value: ReturnType<typeof createMockAsset>[]) => void
const firstPromise = new Promise<ReturnType<typeof createMockAsset>[]>(
(resolve) => {
resolveFirst = resolve
}
)
let callCount = 0
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
async () => {
callCount++
return callCount === 1 ? firstPromise : secondBatch
}
)
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(firstBatch)

// Start two concurrent requests for the same category
const firstRequest = store.updateModelsForNodeType(nodeType)
const secondRequest = store.updateModelsForNodeType(nodeType)
resolveFirst!(firstBatch)
await Promise.all([firstRequest, secondRequest])

expect(store.getAssets(nodeType)).toHaveLength(10)
// Second request should be short-circuited, only one API call made
expect(
store.getAssets(nodeType).every((a) => a.id.startsWith('second-'))
).toBe(true)
vi.mocked(assetService.getAssetsForNodeType)
).toHaveBeenCalledTimes(1)
expect(store.getAssets(nodeType)).toHaveLength(5)
})

it('should allow new request after previous completes', async () => {
const store = useAssetsStore()
const nodeType = 'CheckpointLoaderSimple'
const firstBatch = [createMockAsset('first-1')]
const secondBatch = [
createMockAsset('second-1'),
createMockAsset('second-2')
]

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce(
firstBatch
)
await store.updateModelsForNodeType(nodeType)
expect(store.getAssets(nodeType)).toHaveLength(1)

// After first completes, a new request should work
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce(
secondBatch
)
store.invalidateCategory('checkpoints')
await store.updateModelsForNodeType(nodeType)

expect(store.getAssets(nodeType)).toHaveLength(2)
expect(
vi.mocked(assetService.getAssetsForNodeType)
).toHaveBeenCalledTimes(2)
})
})

Expand All @@ -614,4 +668,87 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
expect(loadingStates).toContain(false)
})
})

describe('category-keyed cache', () => {
it('should share cache between node types of the same category', async () => {
const store = useAssetsStore()
const assets = [createMockAsset('shared-1'), createMockAsset('shared-2')]

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(assets)

await store.updateModelsForNodeType('CheckpointLoaderSimple')

expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(2)
expect(store.getAssets('ImageOnlyCheckpointLoader')).toHaveLength(2)
expect(
vi.mocked(assetService.getAssetsForNodeType)
).toHaveBeenCalledTimes(1)
})

it('should return empty array for unknown node types', () => {
const store = useAssetsStore()
expect(store.getAssets('UnknownNodeType')).toEqual([])
})

it('should not fetch for unknown node types', async () => {
const store = useAssetsStore()
await store.updateModelsForNodeType('UnknownNodeType')
expect(
vi.mocked(assetService.getAssetsForNodeType)
).not.toHaveBeenCalled()
})
})

describe('invalidateCategory', () => {
it('should clear cache for a category', async () => {
const store = useAssetsStore()
const assets = [createMockAsset('asset-1'), createMockAsset('asset-2')]

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(assets)
await store.updateModelsForNodeType('CheckpointLoaderSimple')
expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(2)

store.invalidateCategory('checkpoints')

expect(store.getAssets('CheckpointLoaderSimple')).toEqual([])
expect(store.hasAssetKey('CheckpointLoaderSimple')).toBe(false)
})

it('should allow refetch after invalidation', async () => {
const store = useAssetsStore()
const initialAssets = [createMockAsset('initial-1')]
const refreshedAssets = [
createMockAsset('refreshed-1'),
createMockAsset('refreshed-2')
]

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce(
initialAssets
)
await store.updateModelsForNodeType('LoraLoader')
expect(store.getAssets('LoraLoader')).toHaveLength(1)

store.invalidateCategory('loras')

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce(
refreshedAssets
)
await store.updateModelsForNodeType('LoraLoader')

expect(store.getAssets('LoraLoader')).toHaveLength(2)
})

it('should invalidate tag-based caches', async () => {
const store = useAssetsStore()
const assets = [createMockAsset('tag-asset-1')]

vi.mocked(assetService.getAssetsByTag).mockResolvedValue(assets)
await store.updateModelsForTag('models')
expect(store.getAssets('tag:models')).toHaveLength(1)

store.invalidateCategory('tag:models')

expect(store.getAssets('tag:models')).toEqual([])
})
})
})
Loading