diff --git a/__init__.py b/__init__.py index 11a713d..f5c0bb9 100644 --- a/__init__.py +++ b/__init__.py @@ -5,11 +5,13 @@ # Project: comfy_mtb # Author: Mel Massadian # Copyright (c) 2023 Mel Massadian -# +# ### import os +# todo: don't override this if the user has that setup already os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" +os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async" import traceback from .log import log, blue_text, cyan_text, get_summary, get_label @@ -112,8 +114,16 @@ def load_nodes(): elif web_extensions_root.exists(): web_tgt = here / "web" + src = web_tgt.as_posix() + dst = web_mtb.as_posix() try: - os.symlink(web_tgt.as_posix(), web_mtb.as_posix()) + if os.name == "nt": + import _winapi + + _winapi.CreateJunction(src, dst) + else: + os.symlink(web_tgt.as_posix(), web_mtb.as_posix()) + except OSError: log.warn(f"Failed to create symlink to {web_mtb}, trying to copy it") try: @@ -174,123 +184,121 @@ def load_nodes(): import logging from .endpoint import endlog +if hasattr(PromptServer, "instance"): -@PromptServer.instance.routes.get("/mtb/status") -async def get_full_library(request): - from . import endpoint + @PromptServer.instance.routes.get("/mtb/status") + async def get_full_library(request): + from . import endpoint - reload(endpoint) + reload(endpoint) - endlog.debug("Getting node registration status") - # Check if the request prefers HTML content - if "text/html" in request.headers.get("Accept", ""): - # # Return an HTML page - html_response = endpoint.render_table( - NODE_CLASS_MAPPINGS_DEBUG, title="Registered" - ) - html_response += endpoint.render_table( - {k: "-" for k in failed}, title="Failed to load" + endlog.debug("Getting node registration status") + # Check if the request prefers HTML content + if "text/html" in request.headers.get("Accept", ""): + # # Return an HTML page + html_response = endpoint.render_table( + NODE_CLASS_MAPPINGS_DEBUG, title="Registered" + ) + html_response += endpoint.render_table( + {k: "-" for k in failed}, title="Failed to load" + ) + + return web.Response( + text=endpoint.render_base_template("MTB", html_response), + content_type="text/html", + ) + + return web.json_response( + { + "registered": NODE_CLASS_MAPPINGS_DEBUG, + "failed": failed, + } ) - return web.Response( - text=endpoint.render_base_template("MTB", html_response), - content_type="text/html", + @PromptServer.instance.routes.post("/mtb/debug") + async def set_debug(request): + json_data = await request.json() + enabled = json_data.get("enabled") + if enabled: + os.environ["MTB_DEBUG"] = "true" + log.setLevel(logging.DEBUG) + log.debug("Debug mode set from API (/mtb/debug POST route)") + + else: + if "MTB_DEBUG" in os.environ: + # del os.environ["MTB_DEBUG"] + os.environ.pop("MTB_DEBUG") + log.setLevel(logging.INFO) + + return web.json_response( + {"message": f"Debug mode {'set' if enabled else 'unset'}"} ) - return web.json_response( - { - "registered": NODE_CLASS_MAPPINGS_DEBUG, - "failed": failed, - } - ) + @PromptServer.instance.routes.get("/mtb") + async def get_home(request): + from . import endpoint + + reload(endpoint) + # Check if the request prefers HTML content + if "text/html" in request.headers.get("Accept", ""): + # # Return an HTML page + html_response = f""" + + """ + return web.Response( + text=endpoint.render_base_template("MTB", html_response), + content_type="text/html", + ) + # Return JSON for other requests + return web.json_response({"message": "Welcome to MTB!"}) -@PromptServer.instance.routes.post("/mtb/debug") -async def set_debug(request): - json_data = await request.json() - enabled = json_data.get("enabled") - if enabled: - os.environ["MTB_DEBUG"] = "true" - log.setLevel(logging.DEBUG) - log.debug("Debug mode set from API (/mtb/debug POST route)") + @PromptServer.instance.routes.get("/mtb/debug") + async def get_debug(request): + from . import endpoint - else: + reload(endpoint) + enabled = False if "MTB_DEBUG" in os.environ: - # del os.environ["MTB_DEBUG"] - os.environ.pop("MTB_DEBUG") - log.setLevel(logging.INFO) - - return web.json_response({"message": f"Debug mode {'set' if enabled else 'unset'}"}) - - -@PromptServer.instance.routes.get("/mtb") -async def get_home(request): - from . import endpoint - - reload(endpoint) - # Check if the request prefers HTML content - if "text/html" in request.headers.get("Accept", ""): - # # Return an HTML page - html_response = f""" - - """ - return web.Response( - text=endpoint.render_base_template("MTB", html_response), - content_type="text/html", - ) - - # Return JSON for other requests - return web.json_response({"message": "Welcome to MTB!"}) - - -@PromptServer.instance.routes.get("/mtb/debug") -async def get_debug(request): - from . import endpoint - - reload(endpoint) - enabled = False - if "MTB_DEBUG" in os.environ: - enabled = True - # Check if the request prefers HTML content - if "text/html" in request.headers.get("Accept", ""): - # # Return an HTML page - html_response = f""" -

MTB Debug Status: {'Enabled' if enabled else 'Disabled'}

- """ - return web.Response( - text=endpoint.render_base_template("Debug", html_response), - content_type="text/html", - ) - - # Return JSON for other requests - return web.json_response({"enabled": enabled}) - + enabled = True + # Check if the request prefers HTML content + if "text/html" in request.headers.get("Accept", ""): + # # Return an HTML page + html_response = f""" +

MTB Debug Status: {'Enabled' if enabled else 'Disabled'}

+ """ + return web.Response( + text=endpoint.render_base_template("Debug", html_response), + content_type="text/html", + ) -@PromptServer.instance.routes.get("/mtb/actions") -async def no_route(request): - from . import endpoint + # Return JSON for other requests + return web.json_response({"enabled": enabled}) - if "text/html" in request.headers.get("Accept", ""): - html_response = f""" -

Actions has no get for now...

- """ - return web.Response( - text=endpoint.render_base_template("Actions", html_response), - content_type="text/html", - ) - return web.json_response({"message": "actions has no get for now"}) + @PromptServer.instance.routes.get("/mtb/actions") + async def no_route(request): + from . import endpoint + if "text/html" in request.headers.get("Accept", ""): + html_response = f""" +

Actions has no get for now...

+ """ + return web.Response( + text=endpoint.render_base_template("Actions", html_response), + content_type="text/html", + ) + return web.json_response({"message": "actions has no get for now"}) -@PromptServer.instance.routes.post("/mtb/actions") -async def do_action(request): - from . import endpoint + @PromptServer.instance.routes.post("/mtb/actions") + async def do_action(request): + from . import endpoint - reload(endpoint) + reload(endpoint) - return await endpoint.do_action(request) + return await endpoint.do_action(request) # - WAS Dictionary diff --git a/nodes/faceenhance.py b/nodes/faceenhance.py index 45d4fde..e922559 100644 --- a/nodes/faceenhance.py +++ b/nodes/faceenhance.py @@ -28,6 +28,10 @@ def get_models_root(cls): @classmethod def get_models(cls): models_path = cls.get_models_root() + + if not models_path.exists(): + log.warning(f"No models found at {models_path}") + return [] return [ x diff --git a/pyrightconfig.json b/pyrightconfig.json index 2d041e3..4d820f5 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -13,6 +13,6 @@ "reportMissingImports": true, "reportMissingTypeStubs": false, "pythonVersion": "3.10", - "pythonPlatform": "Windows", + "pythonPlatform": "All", "reportOptionalMemberAccess": "none" } \ No newline at end of file diff --git a/web/imageFeed.js b/web/imageFeed.js index 2c33d1c..824d1ea 100644 --- a/web/imageFeed.js +++ b/web/imageFeed.js @@ -53,10 +53,21 @@ let currentImageIndex = 0 const imageUrls = [] let image_menu = null +let activated = true app.registerExtension({ name: 'mtb.ImageFeed', - setup: async () => { + init: async () => { + const pythongossFeed = app.extensions.find( + (e) => e.name == 'pysssss.ImageFeed' + ) + if (pythongossFeed) { + console.warn( + "[mtb] - Aborting the loading of mtb's imageFeed in favor of pysssss.ImageFeed" + ) + activated = false // just in case other methods are added later on + return + } // - HTML & CSS //- lightbox const lightboxContainer = document.createElement('div') @@ -209,6 +220,9 @@ app.registerExtension({ Object.assign(but.style, { height: '120px', width: '120px', + border: 'none', + padding: 0, + margin: 0, }) Object.assign(img.style, { width: '100%', diff --git a/web/mtb_widgets.js b/web/mtb_widgets.js index 4ea2136..c5f36b4 100644 --- a/web/mtb_widgets.js +++ b/web/mtb_widgets.js @@ -688,7 +688,34 @@ const mtb_widgets = { } break } - case 'Save Gif (mtb)': { + //TODO: remove this non sense + case 'Get Batch From History (mtb)': { + const onNodeCreated = nodeType.prototype.onNodeCreated + nodeType.prototype.onNodeCreated = function () { + const r = onNodeCreated + ? onNodeCreated.apply(this, arguments) + : undefined + const internal_count = this.widgets.find( + (w) => w.name === 'internal_count' + ) + shared.hideWidgetForGood(this, internal_count) + internal_count.afterQueued = function () { + this.value++ + } + + return r + } + + const onExecuted = nodeType.prototype.onExecuted + nodeType.prototype.onExecuted = function (message) { + const r = onExecuted ? onExecuted.apply(this, message) : undefined + return r + } + + break + } + case 'Save Gif (mtb)': + case 'Save Animated Image (mtb)': { const onExecuted = nodeType.prototype.onExecuted nodeType.prototype.onExecuted = function (message) { const prefix = 'anything_' @@ -704,15 +731,25 @@ const mtb_widgets = { } let imgURLs = [] - if (message && message.gif) { - imgURLs = imgURLs.concat( - message.gif.map((params) => { - return api.apiURL( - '/view?' + new URLSearchParams(params).toString() - ) - }) - ) - + if (message) { + if (message.gif) { + imgURLs = imgURLs.concat( + message.gif.map((params) => { + return api.apiURL( + '/view?' + new URLSearchParams(params).toString() + ) + }) + ) + } + if (message.apng) { + imgURLs = imgURLs.concat( + message.apng.map((params) => { + return api.apiURL( + '/view?' + new URLSearchParams(params).toString() + ) + }) + ) + } let i = 0 for (const img of imgURLs) { const w = this.addCustomWidget(