Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert pinia stores from options API to composition API #1330

Merged
merged 10 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 39 additions & 45 deletions src/stores/dialogStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,59 +2,53 @@
// Currently we need to bridge between legacy app code and Vue app with a Pinia store.

import { defineStore } from 'pinia'
import { type Component, markRaw, nextTick } from 'vue'

interface DialogState {
isVisible: boolean
title: string
headerComponent: Component | null
component: Component | null
// Props passing to the component
props: Record<string, any>
// Props passing to the Dialog component
dialogComponentProps: DialogComponentProps
}
import { ref, shallowRef, type Component, markRaw } from 'vue'

interface DialogComponentProps {
maximizable?: boolean
onClose?: () => void
}

export const useDialogStore = defineStore('dialog', {
state: (): DialogState => ({
isVisible: false,
title: '',
headerComponent: null,
component: null,
props: {},
dialogComponentProps: {}
}),
export const useDialogStore = defineStore('dialog', () => {
const isVisible = ref(false)
const title = ref('')
const headerComponent = shallowRef<Component | null>(null)
const component = shallowRef<Component | null>(null)
const props = ref<Record<string, any>>({})
const dialogComponentProps = ref<DialogComponentProps>({})

actions: {
showDialog(options: {
title?: string
headerComponent?: Component
component: Component
props?: Record<string, any>
dialogComponentProps?: DialogComponentProps
}) {
this.isVisible = true
nextTick(() => {
this.title = options.title ?? ''
this.headerComponent = options.headerComponent
? markRaw(options.headerComponent)
: null
this.component = markRaw(options.component)
this.props = options.props || {}
this.dialogComponentProps = options.dialogComponentProps || {}
})
},
function showDialog(options: {
title?: string
headerComponent?: Component
component: Component
props?: Record<string, any>
dialogComponentProps?: DialogComponentProps
}) {
isVisible.value = true
title.value = options.title ?? ''
headerComponent.value = options.headerComponent
? markRaw(options.headerComponent)
: null
component.value = markRaw(options.component)
props.value = options.props || {}
dialogComponentProps.value = options.dialogComponentProps || {}
}

closeDialog() {
if (this.dialogComponentProps.onClose) {
this.dialogComponentProps.onClose()
}
this.isVisible = false
function closeDialog() {
if (dialogComponentProps.value.onClose) {
dialogComponentProps.value.onClose()
}
isVisible.value = false
}

return {
isVisible,
title,
headerComponent,
component,
props,
dialogComponentProps,
showDialog,
closeDialog
}
})
81 changes: 47 additions & 34 deletions src/stores/modelStore.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { api } from '@/scripts/api'
import { ref } from 'vue'
import { defineStore } from 'pinia'
import { api } from '@/scripts/api'

/** (Internal helper) finds a value in a metadata object from any of a list of keys. */
function _findInMetadata(metadata: any, ...keys: string[]): string | null {
Expand Down Expand Up @@ -158,39 +159,51 @@ export class ModelFolder {
const folderBlacklist = ['configs', 'custom_nodes']

/** Model store handler, wraps individual per-folder model stores */
export const useModelStore = defineStore('modelStore', {
state: () => ({
modelStoreMap: {} as Record<string, ModelFolder | null>,
isLoading: {} as Record<string, Promise<ModelFolder | null> | null>,
modelFolders: [] as string[]
}),
actions: {
async getModelsInFolderCached(folder: string): Promise<ModelFolder | null> {
if (folder in this.modelStoreMap) {
return this.modelStoreMap[folder]
}
if (this.isLoading[folder]) {
return this.isLoading[folder]
}
const promise = api.getModels(folder).then((models) => {
if (!models) {
return null
}
const store = new ModelFolder(folder, models)
this.modelStoreMap[folder] = store
this.isLoading[folder] = null
return store
})
this.isLoading[folder] = promise
return promise
},
clearCache() {
this.modelStoreMap = {}
},
async getModelFolders() {
this.modelFolders = (await api.getModelFolders()).filter(
(folder) => !folderBlacklist.includes(folder)
)
export const useModelStore = defineStore('modelStore', () => {
const modelStoreMap = ref<Record<string, ModelFolder | null>>({})
const isLoading = ref<Record<string, Promise<ModelFolder | null> | null>>({})
const modelFolders = ref<string[]>([])

async function getModelsInFolderCached(
folder: string
): Promise<ModelFolder | null> {
if (folder in modelStoreMap.value) {
return modelStoreMap.value[folder]
}
if (isLoading.value[folder]) {
return isLoading.value[folder]
}
const promise = api.getModels(folder).then((models) => {
if (!models) {
return null
}
const store = new ModelFolder(folder, models)
modelStoreMap.value[folder] = store
isLoading.value[folder] = null
return store
})
isLoading.value[folder] = promise
return promise
}

function clearCache() {
Object.keys(modelStoreMap.value).forEach((key) => {
delete modelStoreMap.value[key]
})
}

async function getModelFolders() {
modelFolders.value = (await api.getModelFolders()).filter(
(folder) => !folderBlacklist.includes(folder)
)
}

return {
modelStoreMap,
isLoading,
modelFolders,
getModelsInFolderCached,
clearCache,
getModelFolders
}
})
142 changes: 73 additions & 69 deletions src/stores/modelToNodeStore.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
import { useNodeDefStore } from '@/stores/nodeDefStore'
import { ref } from 'vue'
import { defineStore } from 'pinia'
import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'

/** Helper class that defines how to construct a node from a model. */
export class ModelNodeProvider {
Expand All @@ -17,75 +17,79 @@ export class ModelNodeProvider {
}

/** Service for mapping model types (by folder name) to nodes. */
export const useModelToNodeStore = defineStore('modelToNode', {
state: () => ({
modelToNodeMap: {} as Record<string, ModelNodeProvider[]>,
nodeDefStore: useNodeDefStore(),
haveDefaultsLoaded: false
}),
actions: {
/**
* Get the node provider for the given model type name.
* @param modelType The name of the model type to get the node provider for.
* @returns The node provider for the given model type name.
*/
getNodeProvider(modelType: string): ModelNodeProvider {
this.registerDefaults()
return this.modelToNodeMap[modelType]?.[0]
},

/**
* Get the list of all valid node providers for the given model type name.
* @param modelType The name of the model type to get the node providers for.
* @returns The list of all valid node providers for the given model type name.
*/
getAllNodeProviders(modelType: string): ModelNodeProvider[] {
this.registerDefaults()
return this.modelToNodeMap[modelType] ?? []
},
export const useModelToNodeStore = defineStore('modelToNode', () => {
const modelToNodeMap = ref<Record<string, ModelNodeProvider[]>>({})
const nodeDefStore = useNodeDefStore()
const haveDefaultsLoaded = ref(false)
/**
* Get the node provider for the given model type name.
* @param modelType The name of the model type to get the node provider for.
* @returns The node provider for the given model type name.
*/
function getNodeProvider(modelType: string): ModelNodeProvider | undefined {
registerDefaults()
return modelToNodeMap.value[modelType]?.[0]
}
/**
* Get the list of all valid node providers for the given model type name.
* @param modelType The name of the model type to get the node providers for.
* @returns The list of all valid node providers for the given model type name.
*/
function getAllNodeProviders(modelType: string): ModelNodeProvider[] {
registerDefaults()
return modelToNodeMap.value[modelType] ?? []
}
/**
* Register a node provider for the given model type name.
* @param modelType The name of the model type to register the node provider for.
* @param nodeProvider The node provider to register.
*/
function registerNodeProvider(
modelType: string,
nodeProvider: ModelNodeProvider
) {
registerDefaults()
if (!modelToNodeMap.value[modelType]) {
modelToNodeMap.value[modelType] = []
}
modelToNodeMap.value[modelType].push(nodeProvider)
}
/**
* Register a node provider for the given simple names.
* @param modelType The name of the model type to register the node provider for.
* @param nodeClass The node class name to register.
* @param key The key to use for the node input.
*/
function quickRegister(modelType: string, nodeClass: string, key: string) {
registerNodeProvider(
modelType,
new ModelNodeProvider(nodeDefStore.nodeDefsByName[nodeClass], key)
)
}

/**
* Register a node provider for the given model type name.
* @param modelType The name of the model type to register the node provider for.
* @param nodeProvider The node provider to register.
*/
registerNodeProvider(modelType: string, nodeProvider: ModelNodeProvider) {
this.registerDefaults()
this.modelToNodeMap[modelType] ??= []
this.modelToNodeMap[modelType].push(nodeProvider)
},
function registerDefaults() {
if (haveDefaultsLoaded.value) {
return
}
if (Object.keys(nodeDefStore.nodeDefsByName).length === 0) {
return
}
haveDefaultsLoaded.value = true

/**
* Register a node provider for the given simple names.
* @param modelType The name of the model type to register the node provider for.
* @param nodeClass The node class name to register.
* @param key The key to use for the node input.
*/
quickRegister(modelType: string, nodeClass: string, key: string) {
this.registerNodeProvider(
modelType,
new ModelNodeProvider(this.nodeDefStore.nodeDefsByName[nodeClass], key)
)
},
quickRegister('checkpoints', 'CheckpointLoaderSimple', 'ckpt_name')
quickRegister('checkpoints', 'ImageOnlyCheckpointLoader', 'ckpt_name')
quickRegister('loras', 'LoraLoader', 'lora_name')
quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
quickRegister('vae', 'VAELoader', 'vae_name')
quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
}

registerDefaults() {
if (this.haveDefaultsLoaded) {
return
}
if (Object.keys(this.nodeDefStore.nodeDefsByName).length === 0) {
return
}
this.haveDefaultsLoaded = true
this.quickRegister('checkpoints', 'CheckpointLoaderSimple', 'ckpt_name')
this.quickRegister(
'checkpoints',
'ImageOnlyCheckpointLoader',
'ckpt_name'
)
this.quickRegister('loras', 'LoraLoader', 'lora_name')
this.quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name')
this.quickRegister('vae', 'VAELoader', 'vae_name')
this.quickRegister('controlnet', 'ControlNetLoader', 'control_net_name')
}
return {
modelToNodeMap,
getNodeProvider,
getAllNodeProviders,
registerNodeProvider,
quickRegister,
registerDefaults
}
})
Loading
Loading