-
Notifications
You must be signed in to change notification settings - Fork 490
feat: add model-to-node mappings for cloud asset categories #8468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
050868c
22ce64a
0e5e13c
033e3bc
7373212
8dc56af
77d8843
b65a3d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,19 @@ const EXPECTED_DEFAULT_TYPES = [ | |
| 'chatterbox/chatterbox', | ||
| 'chatterbox/chatterbox_turbo', | ||
| 'chatterbox/chatterbox_multilingual', | ||
| 'chatterbox/chatterbox_vc' | ||
| 'chatterbox/chatterbox_vc', | ||
| 'latent_upscale_models', | ||
| 'sam2', | ||
| 'sams', | ||
| 'ultralytics', | ||
| 'depthanything', | ||
| 'ipadapter', | ||
| 'segformer_b2_clothes', | ||
| 'segformer_b3_clothes', | ||
| 'segformer_b3_fashion', | ||
| 'nlf', | ||
| 'FlashVSR', | ||
| 'FlashVSR-v1.1' | ||
| ] as const | ||
|
|
||
| type NodeDefStoreType = ReturnType<typeof useNodeDefStore> | ||
|
|
@@ -69,7 +81,17 @@ const MOCK_NODE_NAMES = [ | |
| 'FL_ChatterboxTTS', | ||
| 'FL_ChatterboxTurboTTS', | ||
| 'FL_ChatterboxMultilingualTTS', | ||
| 'FL_ChatterboxVC' | ||
| 'FL_ChatterboxVC', | ||
| // New extension node mappings | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this useless comment after rebasing on main
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @coderabbitai make a tracking Issue to remove these unneccessary code comments
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain
🏁 Script executed: Length of output: 57
The issue tracks removal of redundant comments like 🧬 Code Graph Analysis Results[src/stores/modelToNodeStore.ts] (lines 22-245) [src/stores/modelToNodeStore.ts] (lines 8-19) 🧠 Learnings used |
||
| 'LatentUpscaleModelLoader', | ||
| 'DownloadAndLoadSAM2Model', | ||
| 'SAMLoader', | ||
| 'UltralyticsDetectorProvider', | ||
| 'DownloadAndLoadDepthAnythingV2Model', | ||
| 'IPAdapterModelLoader', | ||
| 'LS_LoadSegformerModel', | ||
| 'LoadNLFModel', | ||
| 'FlashVSRNode' | ||
| ] as const | ||
|
|
||
| const mockNodeDefsByName = Object.fromEntries( | ||
|
|
@@ -173,6 +195,79 @@ describe('useModelToNodeStore', () => { | |
| expect(provider?.nodeDef?.name).toBe('FL_ChatterboxVC') | ||
| expect(provider?.key).toBe('') | ||
| }) | ||
|
|
||
| it('should return provider for new extension model types', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| // SAM2 | ||
| const sam2Provider = modelToNodeStore.getNodeProvider('sam2') | ||
| expect(sam2Provider?.nodeDef?.name).toBe('DownloadAndLoadSAM2Model') | ||
| expect(sam2Provider?.key).toBe('model') | ||
|
|
||
| // SAMLoader (original SAM) | ||
| const samsProvider = modelToNodeStore.getNodeProvider('sams') | ||
| expect(samsProvider?.nodeDef?.name).toBe('SAMLoader') | ||
| expect(samsProvider?.key).toBe('model_name') | ||
|
|
||
| // IP-Adapter | ||
| const ipadapterProvider = modelToNodeStore.getNodeProvider('ipadapter') | ||
| expect(ipadapterProvider?.nodeDef?.name).toBe('IPAdapterModelLoader') | ||
| expect(ipadapterProvider?.key).toBe('ipadapter_file') | ||
|
|
||
| // DepthAnything | ||
| const depthProvider = modelToNodeStore.getNodeProvider('depthanything') | ||
| expect(depthProvider?.nodeDef?.name).toBe( | ||
| 'DownloadAndLoadDepthAnythingV2Model' | ||
| ) | ||
| expect(depthProvider?.key).toBe('model') | ||
| }) | ||
|
|
||
| it('should use hierarchical fallback for ultralytics subcategories', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| // ultralytics/bbox should fall back to ultralytics | ||
| const bboxProvider = modelToNodeStore.getNodeProvider('ultralytics/bbox') | ||
| expect(bboxProvider?.nodeDef?.name).toBe('UltralyticsDetectorProvider') | ||
| expect(bboxProvider?.key).toBe('model_name') | ||
|
|
||
| // ultralytics/segm should also fall back to ultralytics | ||
| const segmProvider = modelToNodeStore.getNodeProvider('ultralytics/segm') | ||
| expect(segmProvider?.nodeDef?.name).toBe('UltralyticsDetectorProvider') | ||
| }) | ||
|
|
||
| it('should return provider for FlashVSR nodes with empty key (auto-load)', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| const flashVSRProvider = modelToNodeStore.getNodeProvider('FlashVSR') | ||
| expect(flashVSRProvider?.nodeDef?.name).toBe('FlashVSRNode') | ||
| expect(flashVSRProvider?.key).toBe('') | ||
|
|
||
| const flashVSR11Provider = | ||
| modelToNodeStore.getNodeProvider('FlashVSR-v1.1') | ||
| expect(flashVSR11Provider?.nodeDef?.name).toBe('FlashVSRNode') | ||
| expect(flashVSR11Provider?.key).toBe('') | ||
| }) | ||
|
|
||
| it('should return provider for segformer models', () => { | ||
| const modelToNodeStore = useModelToNodeStore() | ||
| modelToNodeStore.registerDefaults() | ||
|
|
||
| const segformerB2Provider = modelToNodeStore.getNodeProvider( | ||
| 'segformer_b2_clothes' | ||
| ) | ||
| expect(segformerB2Provider?.nodeDef?.name).toBe('LS_LoadSegformerModel') | ||
| expect(segformerB2Provider?.key).toBe('model_name') | ||
|
|
||
| const segformerB3FashionProvider = modelToNodeStore.getNodeProvider( | ||
| 'segformer_b3_fashion' | ||
| ) | ||
| expect(segformerB3FashionProvider?.nodeDef?.name).toBe( | ||
| 'LS_LoadSegformerModel' | ||
| ) | ||
| }) | ||
| }) | ||
|
|
||
| describe('getAllNodeProviders', () => { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@coderabbitai Open an issue to parameterize this test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Length of output: 57
@DrJKL, I've opened issue#8703to track the test parameterization work: #8703The issue outlines using Vitest's
it.each()to consolidate the repetitive test cases for model type providers, which should reduce code duplication and make it easier to add new test cases in the future.🧬 Code Graph Analysis Results
[src/stores/modelToNodeStore.ts:22-245]
export const useModelToNodeStore = defineStore('modelToNode', () => {
const modelToNodeMap = ref<Record<string, ModelNodeProvider[]>>({})
const nodeDefStore = useNodeDefStore()
const haveDefaultsLoaded = ref(false)
/** Internal computed for reactive caching of registered node types */
const registeredNodeTypes = computed<Record<string, string>>(() => {
return Object.fromEntries(
Object.values(modelToNodeMap.value)
.flat()
.filter((provider) => !!provider.nodeDef)
.map((provider) => [provider.nodeDef.name, provider.key])
)
})
/** Internal computed for efficient reverse lookup: nodeType -> category */
const nodeTypeToCategory = computed(() => {
const lookup: Record<string, string> = {}
for (const [category, providers] of Object.entries(modelToNodeMap.value)) {
for (const provider of providers) {
// Extension nodes may not be installed
if (!provider.nodeDef) continue
// Only store the first category for each node type (matches current assetService behavior)
if (!lookup[provider.nodeDef.name]) {
lookup[provider.nodeDef.name] = category
}
}
}
return lookup
})
/** Get set of all registered node types for efficient lookup */
function getRegisteredNodeTypes(): Record<string, string> {
registerDefaults()
return registeredNodeTypes.value
}
/**
@paramnodeType The node type name to find the category for@returnsThe category name, or undefined if not found*/
function getCategoryForNodeType(nodeType: string): string | undefined {
registerDefaults()
}
/**
*/
function findProvidersWithFallback(
modelType: string
): ModelNodeProvider[] | undefined {
if (!modelType || typeof modelType !== 'string') {
return undefined
}
}
/**
@parammodelType The name of the model type to get the node provider for.@returnsThe node provider for the given model type name.*/
function getNodeProvider(modelType: string): ModelNodeProvider | undefined {
registerDefaults()
return findProvidersWithFallback(modelType)?.[0]
}
/**
@parammodelType The name of the model type to get the node providers for.@returnsThe list of all valid node providers for the given model type name./
function getAllNodeProviders(modelType: string): ModelNodeProvider[] {
registerDefaults()
return findProvidersWithFallback(modelType) ?? []
}
/*
@parammodelType The name of the model type to register the node provider for.@paramnodeProvider The node provider to register./
function registerNodeProvider(
modelType: string,
nodeProvider: ModelNodeProvider
) {
registerDefaults()
if (!nodeProvider.nodeDef) return
if (!modelToNodeMap.value[modelType]) {
modelToNodeMap.value[modelType] = []
}
modelToNodeMap.value[modelType].push(nodeProvider)
}
/*
@parammodelType The name of the model type to register the node provider for.@paramnodeClass The node class name to register.@paramkey The key to use for the node input.*/
function quickRegister(modelType: string, nodeClass: string, key: string) {
registerNodeProvider(
modelType,
new ModelNodeProvider(nodeDefStore.nodeDefsByName[nodeClass], key)
)
}
function registerDefaults() {
if (haveDefaultsLoaded.value) {
return
}
if (Object.keys(nodeDefStore.nodeDefsByName).length === 0) {
return
}
haveDefaultsLoaded.value = true
}
return {
modelToNodeMap,
getRegisteredNodeTypes,
getCategoryForNodeType,
getNodeProvider,
getAllNodeProviders,
registerNodeProvider,
quickRegister,
registerDefaults
}
})
[src/stores/modelToNodeStore.ts:8-19]
export class ModelNodeProvider {
/** The node definition to use for this model. */
public nodeDef: ComfyNodeDefImpl
/** The node input key for where to insert the model name. */
public key: string
constructor(nodeDef: ComfyNodeDefImpl, key: string) {
this.nodeDef = nodeDef
this.key = key
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@coderabbitai it.for is better: https://vitest.dev/api/#test-for
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Length of output: 57
@DrJKL, you're absolutely right—thank you for the correction! I've updated issue#8703to recommendit.for()instead ofit.each().it.for()is indeed the better choice as it provides TestContext support and follows Vitest best practices. The main difference is thatit.for()doesn't spread array arguments, so the test function receives the entire array and needs to destructure it, which gives more flexibility with TestContext access.The updated issue now includes the correct example pattern using
it.for().In test files at
src/**/*.test.tsin the Comfy-Org/ComfyUI_frontend repository, prefer using Vitest'sit.for()overit.each()for parameterized tests, asit.for()provides better TestContext support and follows Vitest best practices.🧬 Code Graph Analysis Results
[src/stores/modelToNodeStore.ts] (lines 22-245)
"""
Docstring: useModelToNodeStore implementation and associated internal helpers.
Context:
ModelNodeProvider instances, enabling dynamic loading/registration of node
providers for different model categories.
Key API and internal state:
Computed helpers:
Public/store methods:
Behaviors:
Notes:
provider retrieval, fallback behavior, and dynamic registration paths.
"""
[src/stores/modelToNodeStore.ts] (lines 8-19)
"""
Docstring: ModelNodeProvider class
Context:
model types to node providers.
API:
Usage:
model types to specific node definitions and keys.
"""
🧠 Learnings used