Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions src/platform/assets/components/AssetBrowserModal.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,30 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vu
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { useAssetsStore } from '@/stores/assetsStore'

const mockAssetsByKey = vi.hoisted(() => new Map<string, AssetItem[]>())
const mockLoadingByKey = vi.hoisted(() => new Map<string, boolean>())

vi.mock('@/i18n', () => ({
t: (key: string, params?: Record<string, string>) =>
params ? `${key}:${JSON.stringify(params)}` : key,
d: (date: Date) => date.toLocaleDateString()
}))

vi.mock('@/stores/assetsStore', () => {
const store = {
modelAssetsByNodeType: new Map<string, AssetItem[]>(),
modelLoadingByNodeType: new Map<string, boolean>(),
updateModelsForNodeType: vi.fn(),
updateModelsForTag: vi.fn()
const getAssets = vi.fn((key: string) => mockAssetsByKey.get(key) ?? [])
const isModelLoading = vi.fn(
(key: string) => mockLoadingByKey.get(key) ?? false
)
const updateModelsForNodeType = vi.fn()
const updateModelsForTag = vi.fn()
return {
useAssetsStore: () => ({
getAssets,
isModelLoading,
updateModelsForNodeType,
updateModelsForTag
})
}
return { useAssetsStore: () => store }
})

vi.mock('@/stores/modelToNodeStore', () => ({
Expand Down Expand Up @@ -183,12 +193,10 @@ describe('AssetBrowserModal', () => {
})
}

const mockStore = useAssetsStore()

beforeEach(() => {
vi.resetAllMocks()
mockStore.modelAssetsByNodeType.clear()
mockStore.modelLoadingByNodeType.clear()
mockAssetsByKey.clear()
mockLoadingByKey.clear()
})

describe('Integration with useAssetBrowser', () => {
Expand All @@ -197,7 +205,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('asset1', 'Model A', 'checkpoints'),
createTestAsset('asset2', 'Model B', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)

const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
Expand All @@ -214,7 +222,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('c1', 'model.safetensors', 'checkpoints'),
createTestAsset('l1', 'lora.pt', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)

const wrapper = createWrapper({
nodeType: 'CheckpointLoaderSimple',
Expand All @@ -231,17 +239,18 @@ describe('AssetBrowserModal', () => {

describe('Data fetching', () => {
it('triggers store refresh for node type on mount', async () => {
const store = useAssetsStore()
createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()

expect(mockStore.updateModelsForNodeType).toHaveBeenCalledWith(
expect(store.updateModelsForNodeType).toHaveBeenCalledWith(
'CheckpointLoaderSimple'
)
})

it('displays cached assets immediately from store', async () => {
const assets = [createTestAsset('asset1', 'Cached Model', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)

const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })

Expand All @@ -253,15 +262,16 @@ describe('AssetBrowserModal', () => {
})

it('triggers store refresh for asset type (tag) on mount', async () => {
const store = useAssetsStore()
createWrapper({ assetType: 'models' })
await flushPromises()

expect(mockStore.updateModelsForTag).toHaveBeenCalledWith('models')
expect(store.updateModelsForTag).toHaveBeenCalledWith('models')
})

it('uses tag: prefix for cache key when assetType is provided', async () => {
const assets = [createTestAsset('asset1', 'Tagged Model', 'models')]
mockStore.modelAssetsByNodeType.set('tag:models', assets)
mockAssetsByKey.set('tag:models', assets)

const wrapper = createWrapper({ assetType: 'models' })
await flushPromises()
Expand All @@ -277,7 +287,7 @@ describe('AssetBrowserModal', () => {
describe('Asset Selection', () => {
it('emits asset-select event when asset is selected', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)

const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
Expand All @@ -290,7 +300,7 @@ describe('AssetBrowserModal', () => {

it('executes onSelect callback when provided', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)

const onSelect = vi.fn()
const wrapper = createWrapper({
Expand Down Expand Up @@ -333,7 +343,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('asset1', 'Model A', 'checkpoints'),
createTestAsset('asset2', 'Model B', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)

const wrapper = createWrapper({
nodeType: 'CheckpointLoaderSimple',
Expand Down Expand Up @@ -366,7 +376,7 @@ describe('AssetBrowserModal', () => {

it('passes computed contentTitle to BaseModalLayout when no title prop', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)

const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
Expand Down
18 changes: 6 additions & 12 deletions src/platform/assets/components/AssetBrowserModal.vue
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,21 @@ const cacheKey = computed(() => {
})

// Read directly from store cache - reactive to any store updates
const fetchedAssets = computed(
() => assetStore.modelAssetsByNodeType.get(cacheKey.value) ?? []
)
const fetchedAssets = computed(() => assetStore.getAssets(cacheKey.value))

const isStoreLoading = computed(
() => assetStore.modelLoadingByNodeType.get(cacheKey.value) ?? false
)
const isStoreLoading = computed(() => assetStore.isModelLoading(cacheKey.value))

// Only show loading spinner when loading AND no cached data
const isLoading = computed(
() => isStoreLoading.value && fetchedAssets.value.length === 0
)

async function refreshAssets(): Promise<AssetItem[]> {
async function refreshAssets(): Promise<void> {
if (props.nodeType) {
return await assetStore.updateModelsForNodeType(props.nodeType)
}
if (props.assetType) {
return await assetStore.updateModelsForTag(props.assetType)
await assetStore.updateModelsForNodeType(props.nodeType)
} else if (props.assetType) {
await assetStore.updateModelsForTag(props.assetType)
}
return []
}

// Trigger background refresh on mount
Expand Down
10 changes: 5 additions & 5 deletions src/platform/assets/services/assetService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ describe('assetService', () => {
const result = await assetService.getAssetModels('checkpoints')

expect(api.fetchApi).toHaveBeenCalledWith(
'/assets?include_tags=models,checkpoints&limit=500'
'/assets?include_tags=models%2Ccheckpoints&limit=500'
)
expect(result).toEqual([
expect.objectContaining({ name: 'valid.safetensors', pathIndex: 0 })
Expand Down Expand Up @@ -231,9 +231,9 @@ describe('assetService', () => {
)
expect(result).toEqual(testAssets)

// Verify API call includes correct category
// Verify API call includes correct category (comma is URL-encoded by URLSearchParams)
expect(api.fetchApi).toHaveBeenCalledWith(
'/assets?include_tags=models,checkpoints&limit=500'
'/assets?include_tags=models%2Ccheckpoints&limit=500'
)
})

Expand Down Expand Up @@ -400,7 +400,7 @@ describe('assetService', () => {
})

expect(api.fetchApi).toHaveBeenCalledWith(
'/assets?include_tags=models&limit=500&include_public=true&offset=50'
'/assets?include_tags=models&limit=500&offset=50&include_public=true'
)
expect(result).toEqual(testAssets)
})
Expand All @@ -415,7 +415,7 @@ describe('assetService', () => {
})

expect(api.fetchApi).toHaveBeenCalledWith(
'/assets?include_tags=input&limit=100&include_public=false&offset=25'
'/assets?include_tags=input&limit=100&offset=25&include_public=false'
)
expect(result).toEqual(testAssets)
})
Expand Down
62 changes: 42 additions & 20 deletions src/platform/assets/services/assetService.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { fromZodError } from 'zod-validation-error'

import { st } from '@/i18n'

import {
assetItemSchema,
assetResponseSchema,
Expand All @@ -17,6 +18,16 @@ import type {
import { api } from '@/scripts/api'
import { useModelToNodeStore } from '@/stores/modelToNodeStore'

export interface PaginationOptions {
limit?: number
offset?: number
}

interface AssetRequestOptions extends PaginationOptions {
includeTags: string[]
includePublic?: boolean
}

/**
* Maps CivitAI validation error codes to localized error messages
*/
Expand Down Expand Up @@ -77,9 +88,27 @@ function createAssetService() {
* Handles API response with consistent error handling and Zod validation
*/
async function handleAssetRequest(
url: string,
options: AssetRequestOptions,
context: string
): Promise<AssetResponse> {
const {
includeTags,
limit = DEFAULT_LIMIT,
offset,
includePublic
} = options
const queryParams = new URLSearchParams({
include_tags: includeTags.join(','),
limit: limit.toString()
})
if (offset !== undefined && offset > 0) {
queryParams.set('offset', offset.toString())
}
if (includePublic !== undefined) {
queryParams.set('include_public', includePublic ? 'true' : 'false')
}

const url = `${ASSETS_ENDPOINT}?${queryParams.toString()}`
const res = await api.fetchApi(url)
if (!res.ok) {
throw new Error(
Expand All @@ -101,7 +130,7 @@ function createAssetService() {
*/
async function getAssetModelFolders(): Promise<ModelFolder[]> {
const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG}&limit=${DEFAULT_LIMIT}`,
{ includeTags: [MODELS_TAG] },
'model folders'
)

Expand Down Expand Up @@ -130,7 +159,7 @@ function createAssetService() {
*/
async function getAssetModels(folder: string): Promise<ModelFile[]> {
const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${folder}&limit=${DEFAULT_LIMIT}`,
{ includeTags: [MODELS_TAG, folder] },
`models for ${folder}`
)

Expand Down Expand Up @@ -169,9 +198,15 @@ function createAssetService() {
* and fetching all assets with that category tag
*
* @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple')
* @param options - Pagination options
* @param options.limit - Maximum number of assets to return (default: 500)
* @param options.offset - Number of assets to skip (default: 0)
* @returns Promise<AssetItem[]> - Full asset objects with preserved metadata
*/
async function getAssetsForNodeType(nodeType: string): Promise<AssetItem[]> {
async function getAssetsForNodeType(
nodeType: string,
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
): Promise<AssetItem[]> {
if (!nodeType || typeof nodeType !== 'string') {
return []
}
Expand All @@ -186,7 +221,7 @@ function createAssetService() {

// Fetch assets for this category using same API pattern as getAssetModels
const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}&limit=${DEFAULT_LIMIT}`,
{ includeTags: [MODELS_TAG, category], limit, offset },
`assets for ${nodeType}`
)

Expand Down Expand Up @@ -242,23 +277,10 @@ function createAssetService() {
async function getAssetsByTag(
tag: string,
includePublic: boolean = true,
{
limit = DEFAULT_LIMIT,
offset = 0
}: { limit?: number; offset?: number } = {}
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
): Promise<AssetItem[]> {
const queryParams = new URLSearchParams({
include_tags: tag,
limit: limit.toString(),
include_public: includePublic ? 'true' : 'false'
})

if (offset > 0) {
queryParams.set('offset', offset.toString())
}

const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?${queryParams.toString()}`,
{ includeTags: [tag], limit, offset, includePublic },
`assets for tag ${tag}`
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ const mockGetCategoryForNodeType = vi.fn()

vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => ({
modelAssetsByNodeType: new Map(),
modelLoadingByNodeType: new Map(),
modelErrorByNodeType: new Map(),
getAssets: () => [],
isModelLoading: () => false,
getError: () => undefined,
updateModelsForNodeType: mockUpdateModelsForNodeType
})
}))
Expand Down
Loading
Loading