diff --git a/__init__.py b/__init__.py
index 11a713d..f5c0bb9 100644
--- a/__init__.py
+++ b/__init__.py
@@ -5,11 +5,13 @@
# Project: comfy_mtb
# Author: Mel Massadian
# Copyright (c) 2023 Mel Massadian
-#
+#
###
import os
+# todo: don't override this if the user has that setup already
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
+os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
import traceback
from .log import log, blue_text, cyan_text, get_summary, get_label
@@ -112,8 +114,16 @@ def load_nodes():
elif web_extensions_root.exists():
web_tgt = here / "web"
+ src = web_tgt.as_posix()
+ dst = web_mtb.as_posix()
try:
- os.symlink(web_tgt.as_posix(), web_mtb.as_posix())
+ if os.name == "nt":
+ import _winapi
+
+ _winapi.CreateJunction(src, dst)
+ else:
+ os.symlink(web_tgt.as_posix(), web_mtb.as_posix())
+
except OSError:
log.warn(f"Failed to create symlink to {web_mtb}, trying to copy it")
try:
@@ -174,123 +184,121 @@ def load_nodes():
import logging
from .endpoint import endlog
+if hasattr(PromptServer, "instance"):
-@PromptServer.instance.routes.get("/mtb/status")
-async def get_full_library(request):
- from . import endpoint
+ @PromptServer.instance.routes.get("/mtb/status")
+ async def get_full_library(request):
+ from . import endpoint
- reload(endpoint)
+ reload(endpoint)
- endlog.debug("Getting node registration status")
- # Check if the request prefers HTML content
- if "text/html" in request.headers.get("Accept", ""):
- # # Return an HTML page
- html_response = endpoint.render_table(
- NODE_CLASS_MAPPINGS_DEBUG, title="Registered"
- )
- html_response += endpoint.render_table(
- {k: "-" for k in failed}, title="Failed to load"
+ endlog.debug("Getting node registration status")
+ # Check if the request prefers HTML content
+ if "text/html" in request.headers.get("Accept", ""):
+ # # Return an HTML page
+ html_response = endpoint.render_table(
+ NODE_CLASS_MAPPINGS_DEBUG, title="Registered"
+ )
+ html_response += endpoint.render_table(
+ {k: "-" for k in failed}, title="Failed to load"
+ )
+
+ return web.Response(
+ text=endpoint.render_base_template("MTB", html_response),
+ content_type="text/html",
+ )
+
+ return web.json_response(
+ {
+ "registered": NODE_CLASS_MAPPINGS_DEBUG,
+ "failed": failed,
+ }
)
- return web.Response(
- text=endpoint.render_base_template("MTB", html_response),
- content_type="text/html",
+ @PromptServer.instance.routes.post("/mtb/debug")
+ async def set_debug(request):
+ json_data = await request.json()
+ enabled = json_data.get("enabled")
+ if enabled:
+ os.environ["MTB_DEBUG"] = "true"
+ log.setLevel(logging.DEBUG)
+ log.debug("Debug mode set from API (/mtb/debug POST route)")
+
+ else:
+ if "MTB_DEBUG" in os.environ:
+ # del os.environ["MTB_DEBUG"]
+ os.environ.pop("MTB_DEBUG")
+ log.setLevel(logging.INFO)
+
+ return web.json_response(
+ {"message": f"Debug mode {'set' if enabled else 'unset'}"}
)
- return web.json_response(
- {
- "registered": NODE_CLASS_MAPPINGS_DEBUG,
- "failed": failed,
- }
- )
+ @PromptServer.instance.routes.get("/mtb")
+ async def get_home(request):
+ from . import endpoint
+
+ reload(endpoint)
+ # Check if the request prefers HTML content
+ if "text/html" in request.headers.get("Accept", ""):
+ # # Return an HTML page
+ html_response = f"""
+
+ """
+ return web.Response(
+ text=endpoint.render_base_template("MTB", html_response),
+ content_type="text/html",
+ )
+ # Return JSON for other requests
+ return web.json_response({"message": "Welcome to MTB!"})
-@PromptServer.instance.routes.post("/mtb/debug")
-async def set_debug(request):
- json_data = await request.json()
- enabled = json_data.get("enabled")
- if enabled:
- os.environ["MTB_DEBUG"] = "true"
- log.setLevel(logging.DEBUG)
- log.debug("Debug mode set from API (/mtb/debug POST route)")
+ @PromptServer.instance.routes.get("/mtb/debug")
+ async def get_debug(request):
+ from . import endpoint
- else:
+ reload(endpoint)
+ enabled = False
if "MTB_DEBUG" in os.environ:
- # del os.environ["MTB_DEBUG"]
- os.environ.pop("MTB_DEBUG")
- log.setLevel(logging.INFO)
-
- return web.json_response({"message": f"Debug mode {'set' if enabled else 'unset'}"})
-
-
-@PromptServer.instance.routes.get("/mtb")
-async def get_home(request):
- from . import endpoint
-
- reload(endpoint)
- # Check if the request prefers HTML content
- if "text/html" in request.headers.get("Accept", ""):
- # # Return an HTML page
- html_response = f"""
-
- """
- return web.Response(
- text=endpoint.render_base_template("MTB", html_response),
- content_type="text/html",
- )
-
- # Return JSON for other requests
- return web.json_response({"message": "Welcome to MTB!"})
-
-
-@PromptServer.instance.routes.get("/mtb/debug")
-async def get_debug(request):
- from . import endpoint
-
- reload(endpoint)
- enabled = False
- if "MTB_DEBUG" in os.environ:
- enabled = True
- # Check if the request prefers HTML content
- if "text/html" in request.headers.get("Accept", ""):
- # # Return an HTML page
- html_response = f"""
- MTB Debug Status: {'Enabled' if enabled else 'Disabled'}
- """
- return web.Response(
- text=endpoint.render_base_template("Debug", html_response),
- content_type="text/html",
- )
-
- # Return JSON for other requests
- return web.json_response({"enabled": enabled})
-
+ enabled = True
+ # Check if the request prefers HTML content
+ if "text/html" in request.headers.get("Accept", ""):
+ # # Return an HTML page
+ html_response = f"""
+ MTB Debug Status: {'Enabled' if enabled else 'Disabled'}
+ """
+ return web.Response(
+ text=endpoint.render_base_template("Debug", html_response),
+ content_type="text/html",
+ )
-@PromptServer.instance.routes.get("/mtb/actions")
-async def no_route(request):
- from . import endpoint
+ # Return JSON for other requests
+ return web.json_response({"enabled": enabled})
- if "text/html" in request.headers.get("Accept", ""):
- html_response = f"""
- Actions has no get for now...
- """
- return web.Response(
- text=endpoint.render_base_template("Actions", html_response),
- content_type="text/html",
- )
- return web.json_response({"message": "actions has no get for now"})
+ @PromptServer.instance.routes.get("/mtb/actions")
+ async def no_route(request):
+ from . import endpoint
+ if "text/html" in request.headers.get("Accept", ""):
+ html_response = f"""
+ Actions has no get for now...
+ """
+ return web.Response(
+ text=endpoint.render_base_template("Actions", html_response),
+ content_type="text/html",
+ )
+ return web.json_response({"message": "actions has no get for now"})
-@PromptServer.instance.routes.post("/mtb/actions")
-async def do_action(request):
- from . import endpoint
+ @PromptServer.instance.routes.post("/mtb/actions")
+ async def do_action(request):
+ from . import endpoint
- reload(endpoint)
+ reload(endpoint)
- return await endpoint.do_action(request)
+ return await endpoint.do_action(request)
# - WAS Dictionary
diff --git a/nodes/faceenhance.py b/nodes/faceenhance.py
index 45d4fde..e922559 100644
--- a/nodes/faceenhance.py
+++ b/nodes/faceenhance.py
@@ -28,6 +28,10 @@ def get_models_root(cls):
@classmethod
def get_models(cls):
models_path = cls.get_models_root()
+
+ if not models_path.exists():
+ log.warning(f"No models found at {models_path}")
+ return []
return [
x
diff --git a/pyrightconfig.json b/pyrightconfig.json
index 2d041e3..4d820f5 100644
--- a/pyrightconfig.json
+++ b/pyrightconfig.json
@@ -13,6 +13,6 @@
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"pythonVersion": "3.10",
- "pythonPlatform": "Windows",
+ "pythonPlatform": "All",
"reportOptionalMemberAccess": "none"
}
\ No newline at end of file
diff --git a/web/imageFeed.js b/web/imageFeed.js
index 2c33d1c..824d1ea 100644
--- a/web/imageFeed.js
+++ b/web/imageFeed.js
@@ -53,10 +53,21 @@ let currentImageIndex = 0
const imageUrls = []
let image_menu = null
+let activated = true
app.registerExtension({
name: 'mtb.ImageFeed',
- setup: async () => {
+ init: async () => {
+ const pythongossFeed = app.extensions.find(
+ (e) => e.name == 'pysssss.ImageFeed'
+ )
+ if (pythongossFeed) {
+ console.warn(
+ "[mtb] - Aborting the loading of mtb's imageFeed in favor of pysssss.ImageFeed"
+ )
+ activated = false // just in case other methods are added later on
+ return
+ }
// - HTML & CSS
//- lightbox
const lightboxContainer = document.createElement('div')
@@ -209,6 +220,9 @@ app.registerExtension({
Object.assign(but.style, {
height: '120px',
width: '120px',
+ border: 'none',
+ padding: 0,
+ margin: 0,
})
Object.assign(img.style, {
width: '100%',
diff --git a/web/mtb_widgets.js b/web/mtb_widgets.js
index 4ea2136..c5f36b4 100644
--- a/web/mtb_widgets.js
+++ b/web/mtb_widgets.js
@@ -688,7 +688,34 @@ const mtb_widgets = {
}
break
}
- case 'Save Gif (mtb)': {
+ //TODO: remove this non sense
+ case 'Get Batch From History (mtb)': {
+ const onNodeCreated = nodeType.prototype.onNodeCreated
+ nodeType.prototype.onNodeCreated = function () {
+ const r = onNodeCreated
+ ? onNodeCreated.apply(this, arguments)
+ : undefined
+ const internal_count = this.widgets.find(
+ (w) => w.name === 'internal_count'
+ )
+ shared.hideWidgetForGood(this, internal_count)
+ internal_count.afterQueued = function () {
+ this.value++
+ }
+
+ return r
+ }
+
+ const onExecuted = nodeType.prototype.onExecuted
+ nodeType.prototype.onExecuted = function (message) {
+ const r = onExecuted ? onExecuted.apply(this, message) : undefined
+ return r
+ }
+
+ break
+ }
+ case 'Save Gif (mtb)':
+ case 'Save Animated Image (mtb)': {
const onExecuted = nodeType.prototype.onExecuted
nodeType.prototype.onExecuted = function (message) {
const prefix = 'anything_'
@@ -704,15 +731,25 @@ const mtb_widgets = {
}
let imgURLs = []
- if (message && message.gif) {
- imgURLs = imgURLs.concat(
- message.gif.map((params) => {
- return api.apiURL(
- '/view?' + new URLSearchParams(params).toString()
- )
- })
- )
-
+ if (message) {
+ if (message.gif) {
+ imgURLs = imgURLs.concat(
+ message.gif.map((params) => {
+ return api.apiURL(
+ '/view?' + new URLSearchParams(params).toString()
+ )
+ })
+ )
+ }
+ if (message.apng) {
+ imgURLs = imgURLs.concat(
+ message.apng.map((params) => {
+ return api.apiURL(
+ '/view?' + new URLSearchParams(params).toString()
+ )
+ })
+ )
+ }
let i = 0
for (const img of imgURLs) {
const w = this.addCustomWidget(