Skip to content

Commit

Permalink
fix: ✨ various small things
Browse files Browse the repository at this point in the history
- removed border on imagefeed images.
- don't load mtb.imageFeed if the user has pythongoss's version already.
- fix the promptserver issue when importing mtb from a jupyter notebook
- fix: if the user doesn't have the facemodels downloaded it would crash
- added an internal counter to batchfromhistory to invalidate it at each
  frame, which might not be a good idea.
  • Loading branch information
melMass committed Jul 28, 2023
1 parent 889f08c commit 0e311cf
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 114 deletions.
212 changes: 110 additions & 102 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
# Project: comfy_mtb
# Author: Mel Massadian
# Copyright (c) 2023 Mel Massadian
#
#
###
import os

# todo: don't override this if the user has that setup already
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

import traceback
from .log import log, blue_text, cyan_text, get_summary, get_label
Expand Down Expand Up @@ -112,8 +114,16 @@ def load_nodes():

elif web_extensions_root.exists():
web_tgt = here / "web"
src = web_tgt.as_posix()
dst = web_mtb.as_posix()
try:
os.symlink(web_tgt.as_posix(), web_mtb.as_posix())
if os.name == "nt":
import _winapi

_winapi.CreateJunction(src, dst)
else:
os.symlink(web_tgt.as_posix(), web_mtb.as_posix())

except OSError:
log.warn(f"Failed to create symlink to {web_mtb}, trying to copy it")
try:
Expand Down Expand Up @@ -174,123 +184,121 @@ def load_nodes():
import logging
from .endpoint import endlog

if hasattr(PromptServer, "instance"):

@PromptServer.instance.routes.get("/mtb/status")
async def get_full_library(request):
from . import endpoint
@PromptServer.instance.routes.get("/mtb/status")
async def get_full_library(request):
from . import endpoint

reload(endpoint)
reload(endpoint)

endlog.debug("Getting node registration status")
# Check if the request prefers HTML content
if "text/html" in request.headers.get("Accept", ""):
# # Return an HTML page
html_response = endpoint.render_table(
NODE_CLASS_MAPPINGS_DEBUG, title="Registered"
)
html_response += endpoint.render_table(
{k: "-" for k in failed}, title="Failed to load"
endlog.debug("Getting node registration status")
# Check if the request prefers HTML content
if "text/html" in request.headers.get("Accept", ""):
# # Return an HTML page
html_response = endpoint.render_table(
NODE_CLASS_MAPPINGS_DEBUG, title="Registered"
)
html_response += endpoint.render_table(
{k: "-" for k in failed}, title="Failed to load"
)

return web.Response(
text=endpoint.render_base_template("MTB", html_response),
content_type="text/html",
)

return web.json_response(
{
"registered": NODE_CLASS_MAPPINGS_DEBUG,
"failed": failed,
}
)

return web.Response(
text=endpoint.render_base_template("MTB", html_response),
content_type="text/html",
@PromptServer.instance.routes.post("/mtb/debug")
async def set_debug(request):
json_data = await request.json()
enabled = json_data.get("enabled")
if enabled:
os.environ["MTB_DEBUG"] = "true"
log.setLevel(logging.DEBUG)
log.debug("Debug mode set from API (/mtb/debug POST route)")

else:
if "MTB_DEBUG" in os.environ:
# del os.environ["MTB_DEBUG"]
os.environ.pop("MTB_DEBUG")
log.setLevel(logging.INFO)

return web.json_response(
{"message": f"Debug mode {'set' if enabled else 'unset'}"}
)

return web.json_response(
{
"registered": NODE_CLASS_MAPPINGS_DEBUG,
"failed": failed,
}
)
@PromptServer.instance.routes.get("/mtb")
async def get_home(request):
from . import endpoint

reload(endpoint)
# Check if the request prefers HTML content
if "text/html" in request.headers.get("Accept", ""):
# # Return an HTML page
html_response = f"""
<div class="flex-container menu">
<a href="/mtb/debug">debug</a>
<a href="/mtb/status">status</a>
</div>
"""
return web.Response(
text=endpoint.render_base_template("MTB", html_response),
content_type="text/html",
)

# Return JSON for other requests
return web.json_response({"message": "Welcome to MTB!"})

@PromptServer.instance.routes.post("/mtb/debug")
async def set_debug(request):
json_data = await request.json()
enabled = json_data.get("enabled")
if enabled:
os.environ["MTB_DEBUG"] = "true"
log.setLevel(logging.DEBUG)
log.debug("Debug mode set from API (/mtb/debug POST route)")
@PromptServer.instance.routes.get("/mtb/debug")
async def get_debug(request):
from . import endpoint

else:
reload(endpoint)
enabled = False
if "MTB_DEBUG" in os.environ:
# del os.environ["MTB_DEBUG"]
os.environ.pop("MTB_DEBUG")
log.setLevel(logging.INFO)

return web.json_response({"message": f"Debug mode {'set' if enabled else 'unset'}"})


@PromptServer.instance.routes.get("/mtb")
async def get_home(request):
from . import endpoint

reload(endpoint)
# Check if the request prefers HTML content
if "text/html" in request.headers.get("Accept", ""):
# # Return an HTML page
html_response = f"""
<div class="flex-container menu">
<a href="/mtb/debug">debug</a>
<a href="/mtb/status">status</a>
</div>
"""
return web.Response(
text=endpoint.render_base_template("MTB", html_response),
content_type="text/html",
)

# Return JSON for other requests
return web.json_response({"message": "Welcome to MTB!"})


@PromptServer.instance.routes.get("/mtb/debug")
async def get_debug(request):
from . import endpoint

reload(endpoint)
enabled = False
if "MTB_DEBUG" in os.environ:
enabled = True
# Check if the request prefers HTML content
if "text/html" in request.headers.get("Accept", ""):
# # Return an HTML page
html_response = f"""
<h1>MTB Debug Status: {'Enabled' if enabled else 'Disabled'}</h1>
"""
return web.Response(
text=endpoint.render_base_template("Debug", html_response),
content_type="text/html",
)

# Return JSON for other requests
return web.json_response({"enabled": enabled})

enabled = True
# Check if the request prefers HTML content
if "text/html" in request.headers.get("Accept", ""):
# # Return an HTML page
html_response = f"""
<h1>MTB Debug Status: {'Enabled' if enabled else 'Disabled'}</h1>
"""
return web.Response(
text=endpoint.render_base_template("Debug", html_response),
content_type="text/html",
)

@PromptServer.instance.routes.get("/mtb/actions")
async def no_route(request):
from . import endpoint
# Return JSON for other requests
return web.json_response({"enabled": enabled})

if "text/html" in request.headers.get("Accept", ""):
html_response = f"""
<h1>Actions has no get for now...</h1>
"""
return web.Response(
text=endpoint.render_base_template("Actions", html_response),
content_type="text/html",
)
return web.json_response({"message": "actions has no get for now"})
@PromptServer.instance.routes.get("/mtb/actions")
async def no_route(request):
from . import endpoint

if "text/html" in request.headers.get("Accept", ""):
html_response = f"""
<h1>Actions has no get for now...</h1>
"""
return web.Response(
text=endpoint.render_base_template("Actions", html_response),
content_type="text/html",
)
return web.json_response({"message": "actions has no get for now"})

@PromptServer.instance.routes.post("/mtb/actions")
async def do_action(request):
from . import endpoint
@PromptServer.instance.routes.post("/mtb/actions")
async def do_action(request):
from . import endpoint

reload(endpoint)
reload(endpoint)

return await endpoint.do_action(request)
return await endpoint.do_action(request)


# - WAS Dictionary
Expand Down
4 changes: 4 additions & 0 deletions nodes/faceenhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def get_models_root(cls):
@classmethod
def get_models(cls):
models_path = cls.get_models_root()

if not models_path.exists():
log.warning(f"No models found at {models_path}")
return []

return [
x
Expand Down
2 changes: 1 addition & 1 deletion pyrightconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"pythonVersion": "3.10",
"pythonPlatform": "Windows",
"pythonPlatform": "All",
"reportOptionalMemberAccess": "none"
}
16 changes: 15 additions & 1 deletion web/imageFeed.js
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,21 @@ let currentImageIndex = 0
const imageUrls = []

let image_menu = null
let activated = true

app.registerExtension({
name: 'mtb.ImageFeed',
setup: async () => {
init: async () => {
const pythongossFeed = app.extensions.find(
(e) => e.name == 'pysssss.ImageFeed'
)
if (pythongossFeed) {
console.warn(
"[mtb] - Aborting the loading of mtb's imageFeed in favor of pysssss.ImageFeed"
)
activated = false // just in case other methods are added later on
return
}
// - HTML & CSS
//- lightbox
const lightboxContainer = document.createElement('div')
Expand Down Expand Up @@ -209,6 +220,9 @@ app.registerExtension({
Object.assign(but.style, {
height: '120px',
width: '120px',
border: 'none',
padding: 0,
margin: 0,
})
Object.assign(img.style, {
width: '100%',
Expand Down
57 changes: 47 additions & 10 deletions web/mtb_widgets.js
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,34 @@ const mtb_widgets = {
}
break
}
case 'Save Gif (mtb)': {
//TODO: remove this non sense
case 'Get Batch From History (mtb)': {
const onNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated
? onNodeCreated.apply(this, arguments)
: undefined
const internal_count = this.widgets.find(
(w) => w.name === 'internal_count'
)
shared.hideWidgetForGood(this, internal_count)
internal_count.afterQueued = function () {
this.value++
}

return r
}

const onExecuted = nodeType.prototype.onExecuted
nodeType.prototype.onExecuted = function (message) {
const r = onExecuted ? onExecuted.apply(this, message) : undefined
return r
}

break
}
case 'Save Gif (mtb)':
case 'Save Animated Image (mtb)': {
const onExecuted = nodeType.prototype.onExecuted
nodeType.prototype.onExecuted = function (message) {
const prefix = 'anything_'
Expand All @@ -704,15 +731,25 @@ const mtb_widgets = {
}

let imgURLs = []
if (message && message.gif) {
imgURLs = imgURLs.concat(
message.gif.map((params) => {
return api.apiURL(
'/view?' + new URLSearchParams(params).toString()
)
})
)

if (message) {
if (message.gif) {
imgURLs = imgURLs.concat(
message.gif.map((params) => {
return api.apiURL(
'/view?' + new URLSearchParams(params).toString()
)
})
)
}
if (message.apng) {
imgURLs = imgURLs.concat(
message.apng.map((params) => {
return api.apiURL(
'/view?' + new URLSearchParams(params).toString()
)
})
)
}
let i = 0
for (const img of imgURLs) {
const w = this.addCustomWidget(
Expand Down

0 comments on commit 0e311cf

Please sign in to comment.