diff --git a/src/stores/assetsStore.test.ts b/src/stores/assetsStore.test.ts index 7f982200184..0a70f346dd0 100644 --- a/src/stores/assetsStore.test.ts +++ b/src/stores/assetsStore.test.ts @@ -36,6 +36,41 @@ vi.mock('@/platform/distribution/types', () => ({ } })) +// Mock modelToNodeStore with proper node providers and category lookups +vi.mock('@/stores/modelToNodeStore', () => ({ + useModelToNodeStore: () => ({ + getAllNodeProviders: vi.fn((category: string) => { + const providers: Record< + string, + Array<{ nodeDef: { name: string }; key: string }> + > = { + checkpoints: [ + { nodeDef: { name: 'CheckpointLoaderSimple' }, key: 'ckpt_name' }, + { nodeDef: { name: 'ImageOnlyCheckpointLoader' }, key: 'ckpt_name' } + ], + loras: [ + { nodeDef: { name: 'LoraLoader' }, key: 'lora_name' }, + { nodeDef: { name: 'LoraLoaderModelOnly' }, key: 'lora_name' } + ], + vae: [{ nodeDef: { name: 'VAELoader' }, key: 'vae_name' }] + } + return providers[category] ?? [] + }), + getCategoryForNodeType: vi.fn((nodeType: string) => { + const nodeToCategory: Record = { + CheckpointLoaderSimple: 'checkpoints', + ImageOnlyCheckpointLoader: 'checkpoints', + LoraLoader: 'loras', + LoraLoaderModelOnly: 'loras', + VAELoader: 'vae' + } + return nodeToCategory[nodeType] + }), + getNodeProvider: vi.fn(), + registerDefaults: vi.fn() + }) +})) + // Mock TaskItemImpl vi.mock('@/stores/queueStore', () => ({ TaskItemImpl: class { @@ -546,51 +581,70 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => { expect(assets.map((a) => a.id)).toContain('new-asset') }) - it('should return cached array on subsequent getAssets calls', () => { + it('should return cached array on subsequent getAssets calls', async () => { const store = useAssetsStore() - const nodeType = 'TestLoader' + const nodeType = 'CheckpointLoaderSimple' + const assets = [createMockAsset('cache-test-1')] + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(assets) + await store.updateModelsForNodeType(nodeType) const firstCall = store.getAssets(nodeType) const secondCall = store.getAssets(nodeType) expect(secondCall).toBe(firstCall) + expect(firstCall).toHaveLength(1) }) }) describe('concurrent request handling', () => { - it('should discard stale request when newer request starts', async () => { + it('should short-circuit concurrent calls to prevent duplicate work', async () => { const store = useAssetsStore() const nodeType = 'CheckpointLoaderSimple' const firstBatch = Array.from({ length: 5 }, (_, i) => createMockAsset(`first-${i}`) ) - const secondBatch = Array.from({ length: 10 }, (_, i) => - createMockAsset(`second-${i}`) - ) - let resolveFirst: (value: ReturnType[]) => void - const firstPromise = new Promise[]>( - (resolve) => { - resolveFirst = resolve - } - ) - let callCount = 0 - vi.mocked(assetService.getAssetsForNodeType).mockImplementation( - async () => { - callCount++ - return callCount === 1 ? firstPromise : secondBatch - } - ) + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(firstBatch) + // Start two concurrent requests for the same category const firstRequest = store.updateModelsForNodeType(nodeType) const secondRequest = store.updateModelsForNodeType(nodeType) - resolveFirst!(firstBatch) await Promise.all([firstRequest, secondRequest]) - expect(store.getAssets(nodeType)).toHaveLength(10) + // Second request should be short-circuited, only one API call made expect( - store.getAssets(nodeType).every((a) => a.id.startsWith('second-')) - ).toBe(true) + vi.mocked(assetService.getAssetsForNodeType) + ).toHaveBeenCalledTimes(1) + expect(store.getAssets(nodeType)).toHaveLength(5) + }) + + it('should allow new request after previous completes', async () => { + const store = useAssetsStore() + const nodeType = 'CheckpointLoaderSimple' + const firstBatch = [createMockAsset('first-1')] + const secondBatch = [ + createMockAsset('second-1'), + createMockAsset('second-2') + ] + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce( + firstBatch + ) + await store.updateModelsForNodeType(nodeType) + expect(store.getAssets(nodeType)).toHaveLength(1) + + // After first completes, a new request should work + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce( + secondBatch + ) + store.invalidateCategory('checkpoints') + await store.updateModelsForNodeType(nodeType) + + expect(store.getAssets(nodeType)).toHaveLength(2) + expect( + vi.mocked(assetService.getAssetsForNodeType) + ).toHaveBeenCalledTimes(2) }) }) @@ -614,4 +668,87 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => { expect(loadingStates).toContain(false) }) }) + + describe('category-keyed cache', () => { + it('should share cache between node types of the same category', async () => { + const store = useAssetsStore() + const assets = [createMockAsset('shared-1'), createMockAsset('shared-2')] + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(assets) + + await store.updateModelsForNodeType('CheckpointLoaderSimple') + + expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(2) + expect(store.getAssets('ImageOnlyCheckpointLoader')).toHaveLength(2) + expect( + vi.mocked(assetService.getAssetsForNodeType) + ).toHaveBeenCalledTimes(1) + }) + + it('should return empty array for unknown node types', () => { + const store = useAssetsStore() + expect(store.getAssets('UnknownNodeType')).toEqual([]) + }) + + it('should not fetch for unknown node types', async () => { + const store = useAssetsStore() + await store.updateModelsForNodeType('UnknownNodeType') + expect( + vi.mocked(assetService.getAssetsForNodeType) + ).not.toHaveBeenCalled() + }) + }) + + describe('invalidateCategory', () => { + it('should clear cache for a category', async () => { + const store = useAssetsStore() + const assets = [createMockAsset('asset-1'), createMockAsset('asset-2')] + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue(assets) + await store.updateModelsForNodeType('CheckpointLoaderSimple') + expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(2) + + store.invalidateCategory('checkpoints') + + expect(store.getAssets('CheckpointLoaderSimple')).toEqual([]) + expect(store.hasAssetKey('CheckpointLoaderSimple')).toBe(false) + }) + + it('should allow refetch after invalidation', async () => { + const store = useAssetsStore() + const initialAssets = [createMockAsset('initial-1')] + const refreshedAssets = [ + createMockAsset('refreshed-1'), + createMockAsset('refreshed-2') + ] + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce( + initialAssets + ) + await store.updateModelsForNodeType('LoraLoader') + expect(store.getAssets('LoraLoader')).toHaveLength(1) + + store.invalidateCategory('loras') + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce( + refreshedAssets + ) + await store.updateModelsForNodeType('LoraLoader') + + expect(store.getAssets('LoraLoader')).toHaveLength(2) + }) + + it('should invalidate tag-based caches', async () => { + const store = useAssetsStore() + const assets = [createMockAsset('tag-asset-1')] + + vi.mocked(assetService.getAssetsByTag).mockResolvedValue(assets) + await store.updateModelsForTag('models') + expect(store.getAssets('tag:models')).toHaveLength(1) + + store.invalidateCategory('tag:models') + + expect(store.getAssets('tag:models')).toEqual([]) + }) + }) }) diff --git a/src/stores/assetsStore.ts b/src/stores/assetsStore.ts index 41fb3e878be..c946764010b 100644 --- a/src/stores/assetsStore.ts +++ b/src/stores/assetsStore.ts @@ -279,20 +279,22 @@ export const useAssetsStore = defineStore('assets', () => { } /** - * Model assets cached by node type (e.g., 'CheckpointLoaderSimple', 'LoraLoader') - * Used by multiple loader nodes to avoid duplicate fetches + * Model assets cached by category (e.g., 'checkpoints', 'loras') + * Multiple node types sharing the same category share the same cache entry. + * Public API accepts nodeType for backwards compatibility but translates + * to category internally using modelToNodeStore.getCategoryForNodeType(). * Cloud-only feature - empty Maps in desktop builds */ const getModelState = () => { if (isCloud) { - const modelStateByKey = ref(new Map()) + const modelStateByCategory = ref(new Map()) const assetsArrayCache = new Map< string, { source: Map; array: AssetItem[] } >() - const pendingRequestByKey = new Map() + const pendingRequestByCategory = new Map() function createState( existingAssets?: Map @@ -306,64 +308,103 @@ export const useAssetsStore = defineStore('assets', () => { }) } - function isStale(key: string, state: ModelPaginationState): boolean { - const committed = modelStateByKey.value.get(key) - const pending = pendingRequestByKey.get(key) + function isStale(category: string, state: ModelPaginationState): boolean { + const committed = modelStateByCategory.value.get(category) + const pending = pendingRequestByCategory.get(category) return committed !== state && pending !== state } const EMPTY_ASSETS: AssetItem[] = [] + /** + * Resolve a key to a category. Handles both nodeType and tag:xxx formats. + * @param key Either a nodeType (e.g., 'CheckpointLoaderSimple') or tag key (e.g., 'tag:models') + * @returns The category or undefined if not resolvable + */ + function resolveCategory(key: string): string | undefined { + if (key.startsWith('tag:')) { + return key + } + return modelToNodeStore.getCategoryForNodeType(key) + } + + /** + * Get assets by nodeType or tag key. + * Translates nodeType to category internally for cache lookup. + * @param key Either a nodeType (e.g., 'CheckpointLoaderSimple') or tag key (e.g., 'tag:models') + */ function getAssets(key: string): AssetItem[] { - const state = modelStateByKey.value.get(key) + const category = resolveCategory(key) + if (!category) return EMPTY_ASSETS + + const state = modelStateByCategory.value.get(category) const assetsMap = state?.assets if (!assetsMap) return EMPTY_ASSETS - const cached = assetsArrayCache.get(key) + const cached = assetsArrayCache.get(category) if (cached && cached.source === assetsMap) { return cached.array } const array = Array.from(assetsMap.values()) - assetsArrayCache.set(key, { source: assetsMap, array }) + assetsArrayCache.set(category, { source: assetsMap, array }) return array } function isLoading(key: string): boolean { - return modelStateByKey.value.get(key)?.isLoading ?? false + const category = resolveCategory(key) + if (!category) return false + return modelStateByCategory.value.get(category)?.isLoading ?? false } function getError(key: string): Error | undefined { - return modelStateByKey.value.get(key)?.error + const category = resolveCategory(key) + if (!category) return undefined + return modelStateByCategory.value.get(category)?.error } function hasMore(key: string): boolean { - return modelStateByKey.value.get(key)?.hasMore ?? false + const category = resolveCategory(key) + if (!category) return false + return modelStateByCategory.value.get(category)?.hasMore ?? false } function hasAssetKey(key: string): boolean { - return modelStateByKey.value.has(key) + const category = resolveCategory(key) + if (!category) return false + return modelStateByCategory.value.has(category) } /** - * Internal helper to fetch and cache assets with a given key and fetcher. + * Internal helper to fetch and cache assets for a category. * Loads first batch immediately, then progressively loads remaining batches. * Keeps existing data visible until new data is successfully fetched. + * + * Concurrent calls for the same category are short-circuited: if a request + * is already in progress (tracked via pendingRequestByCategory), subsequent + * calls return immediately to avoid redundant work. */ - async function updateModelsForKey( - key: string, + async function updateModelsForCategory( + category: string, fetcher: (options: PaginationOptions) => Promise ): Promise { - const existingState = modelStateByKey.value.get(key) + // Short-circuit if a request for this category is already in progress + if (pendingRequestByCategory.has(category)) { + return + } + + const existingState = modelStateByCategory.value.get(category) const state = createState(existingState?.assets) const seenIds = new Set() - const hasExistingData = modelStateByKey.value.has(key) + const hasExistingData = modelStateByCategory.value.has(category) if (hasExistingData) { - pendingRequestByKey.set(key, state) + pendingRequestByCategory.set(category, state) } else { - modelStateByKey.value.set(key, state) + // Also track in pending map for initial loads to prevent concurrent calls + pendingRequestByCategory.set(category, state) + modelStateByCategory.value.set(category, state) } async function loadBatches(): Promise { @@ -374,14 +415,14 @@ export const useAssetsStore = defineStore('assets', () => { offset: state.offset }) - if (isStale(key, state)) return + if (isStale(category, state)) return const isFirstBatch = state.offset === 0 if (isFirstBatch) { - assetsArrayCache.delete(key) + assetsArrayCache.delete(category) if (hasExistingData) { - pendingRequestByKey.delete(key) - modelStateByKey.value.set(key, state) + pendingRequestByCategory.delete(category) + modelStateByCategory.value.set(category, state) } } @@ -403,13 +444,13 @@ export const useAssetsStore = defineStore('assets', () => { await new Promise((resolve) => setTimeout(resolve, 50)) } } catch (err) { - if (isStale(key, state)) return - console.error(`Error loading batch for ${key}:`, err) + if (isStale(category, state)) return + console.error(`Error loading batch for ${category}:`, err) state.error = err instanceof Error ? err : new Error(String(err)) state.hasMore = false state.isLoading = false - pendingRequestByKey.delete(key) + pendingRequestByCategory.delete(category) return } @@ -421,18 +462,25 @@ export const useAssetsStore = defineStore('assets', () => { for (const id of staleIds) { state.assets.delete(id) } - assetsArrayCache.delete(key) + assetsArrayCache.delete(category) + pendingRequestByCategory.delete(category) } await loadBatches() } /** - * Fetch and cache model assets for a specific node type + * Fetch and cache model assets for a specific node type. + * Translates nodeType to category internally - multiple node types + * sharing the same category will share the same cache entry. * @param nodeType The node type to fetch assets for (e.g., 'CheckpointLoaderSimple') */ async function updateModelsForNodeType(nodeType: string): Promise { - await updateModelsForKey(nodeType, (opts) => + const category = modelToNodeStore.getCategoryForNodeType(nodeType) + if (!category) return + + // Use category as cache key but fetch using nodeType for API compatibility + await updateModelsForCategory(category, (opts) => assetService.getAssetsForNodeType(nodeType, opts) ) } @@ -442,12 +490,23 @@ export const useAssetsStore = defineStore('assets', () => { * @param tag The tag to fetch assets for (e.g., 'models') */ async function updateModelsForTag(tag: string): Promise { - const key = `tag:${tag}` - await updateModelsForKey(key, (opts) => + const category = `tag:${tag}` + await updateModelsForCategory(category, (opts) => assetService.getAssetsByTag(tag, true, opts) ) } + /** + * Invalidate the cache for a specific category. + * Forces a refetch on next access. + * @param category The category to invalidate (e.g., 'checkpoints', 'loras') + */ + function invalidateCategory(category: string): void { + modelStateByCategory.value.delete(category) + assetsArrayCache.delete(category) + pendingRequestByCategory.delete(category) + } + /** * Optimistically update an asset in the cache * @param assetId The asset ID to update @@ -459,19 +518,22 @@ export const useAssetsStore = defineStore('assets', () => { updates: Partial, cacheKey?: string ) { - const keysToCheck = cacheKey - ? [cacheKey] - : Array.from(modelStateByKey.value.keys()) + const category = cacheKey ? resolveCategory(cacheKey) : undefined + if (cacheKey && !category) return + + const categoriesToCheck = category + ? [category] + : Array.from(modelStateByCategory.value.keys()) - for (const key of keysToCheck) { - const state = modelStateByKey.value.get(key) + for (const cat of categoriesToCheck) { + const state = modelStateByCategory.value.get(cat) if (!state?.assets) continue const existingAsset = state.assets.get(assetId) if (existingAsset) { const updatedAsset = { ...existingAsset, ...updates } state.assets.set(assetId, updatedAsset) - assetsArrayCache.delete(key) + assetsArrayCache.delete(cat) if (cacheKey) return } } @@ -554,6 +616,7 @@ export const useAssetsStore = defineStore('assets', () => { hasAssetKey, updateModelsForNodeType, updateModelsForTag, + invalidateCategory, updateAssetMetadata, updateAssetTags } @@ -567,6 +630,7 @@ export const useAssetsStore = defineStore('assets', () => { hasMore: () => false, hasAssetKey: () => false, updateModelsForNodeType: async () => {}, + invalidateCategory: () => {}, updateModelsForTag: async () => {}, updateAssetMetadata: async () => {}, updateAssetTags: async () => {} @@ -581,6 +645,7 @@ export const useAssetsStore = defineStore('assets', () => { hasAssetKey, updateModelsForNodeType, updateModelsForTag, + invalidateCategory, updateAssetMetadata, updateAssetTags } = getModelState() @@ -657,6 +722,7 @@ export const useAssetsStore = defineStore('assets', () => { // Model assets - actions updateModelsForNodeType, updateModelsForTag, + invalidateCategory, updateAssetMetadata, updateAssetTags }