Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions src/stores/assetsStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,41 @@ vi.mock('@/platform/distribution/types', () => ({
}
}))

// Mock modelToNodeStore with proper node providers and category lookups
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: () => ({
getAllNodeProviders: vi.fn((category: string) => {
const providers: Record<
string,
Array<{ nodeDef: { name: string }; key: string }>
> = {
checkpoints: [
{ nodeDef: { name: 'CheckpointLoaderSimple' }, key: 'ckpt_name' },
{ nodeDef: { name: 'ImageOnlyCheckpointLoader' }, key: 'ckpt_name' }
],
loras: [
{ nodeDef: { name: 'LoraLoader' }, key: 'lora_name' },
{ nodeDef: { name: 'LoraLoaderModelOnly' }, key: 'lora_name' }
],
vae: [{ nodeDef: { name: 'VAELoader' }, key: 'vae_name' }]
}
return providers[category] ?? []
}),
getCategoryForNodeType: vi.fn((nodeType: string) => {
const nodeToCategory: Record<string, string> = {
CheckpointLoaderSimple: 'checkpoints',
ImageOnlyCheckpointLoader: 'checkpoints',
LoraLoader: 'loras',
LoraLoaderModelOnly: 'loras',
VAELoader: 'vae'
}
return nodeToCategory[nodeType]
}),
getNodeProvider: vi.fn(),
registerDefaults: vi.fn()
})
}))

// Mock TaskItemImpl
vi.mock('@/stores/queueStore', () => ({
TaskItemImpl: class {
Expand Down Expand Up @@ -614,4 +649,87 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
expect(loadingStates).toContain(false)
})
})

describe('category-keyed cache', () => {
it('should share cache between node types of the same category', async () => {
const store = useAssetsStore()
const assets = [createMockAsset('shared-1'), createMockAsset('shared-2')]

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(assets)

await store.updateModelsForNodeType('CheckpointLoaderSimple')

expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(2)
expect(store.getAssets('ImageOnlyCheckpointLoader')).toHaveLength(2)
expect(
vi.mocked(assetService.getAssetsForNodeType)
).toHaveBeenCalledTimes(1)
})

it('should return empty array for unknown node types', () => {
const store = useAssetsStore()
expect(store.getAssets('UnknownNodeType')).toEqual([])
})

it('should not fetch for unknown node types', async () => {
const store = useAssetsStore()
await store.updateModelsForNodeType('UnknownNodeType')
expect(
vi.mocked(assetService.getAssetsForNodeType)
).not.toHaveBeenCalled()
})
})

describe('invalidateCategory', () => {
it('should clear cache for a category', async () => {
const store = useAssetsStore()
const assets = [createMockAsset('asset-1'), createMockAsset('asset-2')]

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(assets)
await store.updateModelsForNodeType('CheckpointLoaderSimple')
expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(2)

store.invalidateCategory('checkpoints')

expect(store.getAssets('CheckpointLoaderSimple')).toEqual([])
expect(store.hasAssetKey('CheckpointLoaderSimple')).toBe(false)
})

it('should allow refetch after invalidation', async () => {
const store = useAssetsStore()
const initialAssets = [createMockAsset('initial-1')]
const refreshedAssets = [
createMockAsset('refreshed-1'),
createMockAsset('refreshed-2')
]

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce(
initialAssets
)
await store.updateModelsForNodeType('LoraLoader')
expect(store.getAssets('LoraLoader')).toHaveLength(1)

store.invalidateCategory('loras')

vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce(
refreshedAssets
)
await store.updateModelsForNodeType('LoraLoader')

expect(store.getAssets('LoraLoader')).toHaveLength(2)
})

it('should invalidate tag-based caches', async () => {
const store = useAssetsStore()
const assets = [createMockAsset('tag-asset-1')]

vi.mocked(assetService.getAssetsByTag).mockResolvedValue(assets)
await store.updateModelsForTag('models')
expect(store.getAssets('tag:models')).toHaveLength(1)

store.invalidateCategory('tag:models')

expect(store.getAssets('tag:models')).toEqual([])
})
})
})
Loading