Skip to content

Commit

Permalink
Merge pull request #1011 from bghira/debug/webhook-contents
Browse files Browse the repository at this point in the history
fix webhook contents for discord
  • Loading branch information
bghira authored Sep 30, 2024
2 parents 97adf06 + 8831a38 commit e615252
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
15 changes: 10 additions & 5 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
from helpers.training.optimizer_param import optimizer_choices

bf16_only_optims = [
key for key, value in optimizer_choices.items() if value.get("precision", "any") == "bf16"
key
for key, value in optimizer_choices.items()
if value.get("precision", "any") == "bf16"
]
any_precision_optims = [
key for key, value in optimizer_choices.items() if value.get("precision", "any") == "any"
key
for key, value in optimizer_choices.items()
if value.get("precision", "any") == "any"
]
model_classes = {
"full": [
Expand All @@ -17,10 +21,10 @@
"pixart_sigma",
"kolors",
"sd3",
"stable_diffusion_legacy",
"legacy",
],
"lora": ["flux", "sdxl", "kolors", "sd3", "stable_diffusion_legacy"],
"controlnet": ["sdxl", "stable_diffusion_legacy"],
"lora": ["flux", "sdxl", "kolors", "sd3", "legacy"],
"controlnet": ["sdxl", "legacy"],
}

default_models = {
Expand All @@ -30,6 +34,7 @@
"kolors": "kwai-kolors/kolors-diffusers",
"terminus": "ptx0/terminus-xl-velocity-v2",
"sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
"legacy": "stabilityai/stable-diffusion-2-1-base",
}

default_cfg = {
Expand Down
12 changes: 10 additions & 2 deletions helpers/webhooks/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def _send_request(
# Prepare Discord-style payload
data = {"content": f"{self.message_prefix}{message}"}
files = self._prepare_images(images)
request_args = {
"data": data,
"files": files if self.webhook_type == "discord" else None,
}
elif self.webhook_type == "raw":
# Prepare raw data payload for direct POST
if raw_request:
Expand All @@ -70,16 +74,20 @@ def _send_request(
),
}
files = None
request_args = {
"json": data,
"files": None,
}
else:
logger.error(f"Unsupported webhook type: {self.webhook_type}")
return

# Send request
try:
logger.debug(f"Sending webhook request: {request_args}")
post_result = requests.post(
self.webhook_url,
json=data,
files=files if self.webhook_type == "discord" else None,
**request_args,
)
post_result.raise_for_status()
except Exception as e:
Expand Down
10 changes: 5 additions & 5 deletions tests/test_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def test_send_message_info_level(self, mock_post):
# Capture the call arguments
args, kwargs = mock_post.call_args
# Assuming the message is sent in 'data' parameter
self.assertIn("json", kwargs)
self.assertIn(message, kwargs["json"].get("content"))
self.assertIn("data", kwargs)
self.assertIn(message, kwargs["data"].get("content"))

@patch("requests.post")
def test_debug_message_wont_send(self, mock_post):
Expand All @@ -68,9 +68,9 @@ def test_send_with_images(self, mock_post):
self.assertIn("files", kwargs)
self.assertEqual(len(kwargs["files"]), 1)
# Check that the message is in the 'data' parameter
content = kwargs.get("json", {}).get("content", "")
content = kwargs.get("data", {}).get("content", "")
self.assertIn(self.mock_config_instance.values.get("message_prefix"), content)
self.assertIn("json", kwargs, f"Check data for contents: {kwargs}")
self.assertIn("data", kwargs, f"Check data for contents: {kwargs}")
self.assertIn(message, content)

@patch("requests.post")
Expand All @@ -84,7 +84,7 @@ def test_response_storage(self, mock_post):
self.assertEqual(self.handler.stored_response, mock_response.headers)
# Also check that the message is sent
args, kwargs = mock_post.call_args
content = kwargs.get("json", {}).get("content", "")
content = kwargs.get("data", {}).get("content", "")
self.assertIn(self.mock_config_instance.values.get("message_prefix"), content)
self.assertIn("Test message", content)

Expand Down

0 comments on commit e615252

Please sign in to comment.