merge masked loss + reg image fixes #1080

Merged · 6 commits · Oct 19, 2024
13 changes: 7 additions & 6 deletions helpers/data_backend/factory.py
@@ -559,15 +559,19 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
or backend["id"] in StateTracker.get_data_backends()
):
raise ValueError("Each dataset needs a unique 'id' field.")
info_log(f"Configuring data backend: {backend['id']}")
conditioning_type = backend.get("conditioning_type")
if backend.get("dataset_type") == "conditioning" or conditioning_type is not None:
backend['dataset_type'] = 'conditioning'
resolution_type = backend.get("resolution_type", args.resolution_type)
if resolution_type == "pixel_area":
pixel_edge_length = backend.get("resolution")
pixel_edge_length = backend.get("resolution", int(args.resolution))
if pixel_edge_length is None or (
type(pixel_edge_length) is not int
or not str(pixel_edge_length).isdigit()
):
raise ValueError(
f"Resolution type 'pixel_area' requires a 'resolution' field to be set in the backend config using an integer in the format: 1024"
f"Resolution type 'pixel_area' requires a 'resolution' field to be set in the backend config using an integer in the format: 1024, but {pixel_edge_length} was given"
)
# we'll convert pixel_area to area
backend["resolution_type"] = "area"
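
For context, a minimal standalone sketch of what the pixel_area branch presumably does once validation passes: the integer edge length is squared and converted to megapixels before the backend is re-labelled as an area backend. The backend dict below is illustrative, not taken from the PR.

backend = {"resolution_type": "pixel_area", "resolution": 1024}  # hypothetical config

if backend["resolution_type"] == "pixel_area":
    edge = backend["resolution"]
    if type(edge) is not int:
        raise ValueError(
            f"Resolution type 'pixel_area' requires a 'resolution' field to be set in the backend config using an integer in the format: 1024, but {edge} was given"
        )
    # 1024 x 1024 px -> ~1.05 megapixels, matching the minimum_image_size conversion below
    backend["resolution"] = (edge * edge) / 1_000_000
    backend["resolution_type"] = "area"

print(backend)  # {'resolution_type': 'area', 'resolution': 1.048576}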
@@ -596,7 +600,6 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
backend["minimum_image_size"] * backend["minimum_image_size"]
) / 1_000_000

info_log(f"Configuring data backend: {backend['id']}")
# Retrieve some config file overrides for commandline arguments, eg. cropping
init_backend = init_backend_config(backend, args, accelerator)
StateTracker.set_data_backend_config(
@@ -748,9 +751,6 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
repeats=init_backend["config"].get("repeats", 0),
**metadata_backend_args,
)
conditioning_type = None
if backend.get("dataset_type") == "conditioning":
conditioning_type = backend.get("conditioning_type", "controlnet")

if (
"aspect" not in args.skip_file_discovery
@@ -1066,6 +1066,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize

if (
not args.vae_cache_ondemand
and "vaecache" in init_backend
and "vae" not in args.skip_file_discovery
and "vae" not in backend.get("skip_file_discovery", "")
and "deepfloyd" not in StateTracker.get_args().model_type
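
The added "vaecache" in init_backend condition matters for conditioning datasets such as mask sets, which can be configured without a VAE cache; without it, the discovery block would hit a KeyError. A hedged sketch of the failure mode, where both the dict contents and the discover_all_files method name are illustrative:

init_backend = {"id": "mask-data"}  # hypothetical conditioning backend, no "vaecache" key

if "vaecache" in init_backend:  # the new guard
    init_backend["vaecache"].discover_all_files()  # would raise KeyError without the guard
else:
    print(f"Backend {init_backend['id']} has no VAE cache; skipping VAE discovery.")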
2 changes: 1 addition & 1 deletion helpers/image_manipulation/training_sample.py
@@ -80,7 +80,7 @@ def __init__(
self.resolution = self.data_backend_config.get("resolution")
self.resolution_type = self.data_backend_config.get("resolution_type")
self.target_size_calculator = resize_helpers.get(self.resolution_type)
if self.target_size_calculator is None:
if self.target_size_calculator is None and conditioning_type not in ["mask", "controlnet"]:
raise ValueError(f"Unknown resolution type: {self.resolution_type}")
self._set_resolution()
self.target_downsample_size = self.data_backend_config.get(
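
The one-line change relaxes the resolution-type lookup for conditioning samples: a mask or controlnet image may carry a resolution_type with no registered size calculator, since its geometry follows the parent sample. A rough sketch of the dispatch, with an assumed resize_helpers table:

resize_helpers = {
    "pixel": lambda resolution, size: (resolution, resolution),  # assumed entries; the real
    "area": lambda resolution, size: size,                       # table lives in this module
}

resolution_type = "pixel_area"  # hypothetical type missing from the table
conditioning_type = "mask"

target_size_calculator = resize_helpers.get(resolution_type)
if target_size_calculator is None and conditioning_type not in ["mask", "controlnet"]:
    raise ValueError(f"Unknown resolution type: {resolution_type}")
# mask/controlnet samples fall through instead of raising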
77 changes: 47 additions & 30 deletions helpers/training/trainer.py
@@ -2412,24 +2412,30 @@ def train(self):
# x-prediction requires that we now subtract the noise residual from the prediction to get the target sample.
if (
hasattr(self.noise_scheduler, "config")
and hasattr(self.noise_scheduler.config, "prediction_type")
and self.noise_scheduler.config.prediction_type == "sample"
and hasattr(
self.noise_scheduler.config, "prediction_type"
)
and self.noise_scheduler.config.prediction_type
== "sample"
):
model_pred = model_pred - noise

parent_loss = None

# Compute the per-pixel loss without reducing over spatial dimensions
if self.config.flow_matching:
loss = torch.mean(
((model_pred.float() - target.float()) ** 2).reshape(
target.shape[0], -1
),
1,
)
elif self.config.snr_gamma is None or self.config.snr_gamma == 0:
# For flow matching, compute the per-pixel squared differences
loss = (
model_pred.float() - target.float()
) ** 2 # Shape: (batch_size, C, H, W)
elif (
self.config.snr_gamma is None
or self.config.snr_gamma == 0
):
training_logger.debug("Calculating loss")
loss = self.config.snr_weight * F.mse_loss(
model_pred.float(), target.float(), reduction="none"
)
) # Shape: (batch_size, C, H, W)
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
@@ -2442,7 +2448,8 @@ def train(self):
== "v_prediction"
or (
self.config.flow_matching
and self.config.flow_matching_loss == "diffusion"
and self.config.flow_matching_loss
== "diffusion"
)
):
snr_divisor = snr + 1
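
The snr_divisor chosen above feeds the min-SNR weighting in the next hunk: per sample the weight is min(SNR, snr_gamma) / snr_divisor, with snr_divisor = SNR + 1 for v-prediction (and flow matching with diffusion loss) and SNR otherwise. A small numeric sketch with made-up SNR values:

import torch

snr = torch.tensor([0.5, 4.0, 25.0])  # hypothetical per-sample SNR values
snr_gamma = 5.0
snr_divisor = snr + 1  # v_prediction; epsilon-prediction would divide by snr instead

mse_loss_weights = (
    torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0]
    / snr_divisor
)
print(mse_loss_weights)  # tensor([0.3333, 0.8000, 0.1923])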
@@ -2454,52 +2461,62 @@ def train(self):
torch.stack(
[
snr,
self.config.snr_gamma * torch.ones_like(timesteps),
self.config.snr_gamma
* torch.ones_like(timesteps),
],
dim=1,
).min(dim=1)[0]
/ snr_divisor
)
) # Shape: (batch_size,)

# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
# Compute the per-pixel MSE loss without reduction
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="none"
)
) # Shape: (batch_size, C, H, W)

# Reshape mse_loss_weights for broadcasting and apply to loss
mse_loss_weights = mse_loss_weights.view(
-1, 1, 1, 1
) # Shape: (batch_size, 1, 1, 1)
loss = (
loss.mean(dim=list(range(1, len(loss.shape))))
* mse_loss_weights
)
loss * mse_loss_weights
) # Shape: (batch_size, C, H, W)

# Mask the loss using any conditioning data
conditioning_type = batch.get("conditioning_type")
if conditioning_type == "mask":
# adapted from: https://github.com/kohya-ss/sd-scripts/blob/main/library/custom_train_functions.py#L482
# Adapted from:
# https://github.com/kohya-ss/sd-scripts/blob/main/library/custom_train_functions.py#L482
mask_image = (
batch["conditioning_pixel_values"]
.to(dtype=loss.dtype, device=loss.device)[:, 0]
.unsqueeze(1)
)
) # Shape: (batch_size, 1, H', W')
mask_image = torch.nn.functional.interpolate(
mask_image, size=loss.shape[2:], mode="area"
)
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
) # Resize to match loss spatial dimensions
mask_image = mask_image / 2 + 0.5 # Normalize to [0,1]
loss = loss * mask_image # Element-wise multiplication

# Reduce the loss by averaging over channels and spatial dimensions
loss = loss.mean(
dim=list(range(1, len(loss.shape)))
) # Shape: (batch_size,)

# Further reduce the loss by averaging over the batch dimension
loss = loss.mean() # Scalar value

# reduce loss now
loss = loss.mean()
if is_regularisation_data:
parent_loss = loss

# Gather the losses across all processes for logging (if we use distributed training).
# Gather the losses across all processes for logging (if using distributed training)
avg_loss = self.accelerator.gather(
loss.repeat(self.config.train_batch_size)
).mean()
self.train_loss += (
avg_loss.item() / self.config.gradient_accumulation_steps
avg_loss.item()
/ self.config.gradient_accumulation_steps
)

# Backpropagate
grad_norm = None
if not self.config.disable_accelerator:
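
Taken together, the trainer now keeps the loss at full (batch, channel, height, width) resolution until the conditioning mask has been applied, and only then reduces. A self-contained sketch of that path with dummy tensors; the shapes and the [-1, 1] mask range mirror the diff, everything else is assumed:

import torch
import torch.nn.functional as F

B, C, H, W = 2, 4, 64, 64
model_pred = torch.randn(B, C, H, W)
target = torch.randn(B, C, H, W)
# single-channel mask in [-1, 1], stored at a larger (pixel) resolution
mask_image = torch.rand(B, 1, H * 8, W * 8) * 2 - 1

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")   # (B, C, H, W)
mask_image = F.interpolate(mask_image, size=loss.shape[2:], mode="area")  # match loss H, W
mask_image = mask_image / 2 + 0.5                                         # normalize to [0, 1]
loss = loss * mask_image                                                  # broadcasts over C
loss = loss.mean(dim=list(range(1, len(loss.shape))))                     # (B,)
loss = loss.mean()                                                        # scalar for backprop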
161 changes: 143 additions & 18 deletions toolkit/datasets/masked_loss/generate_dataset_masks.py
@@ -1,7 +1,96 @@
import argparse
import os
import shutil
from gradio_client import Client, handle_file

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from typing import Union, Any, Tuple, Dict
from unittest.mock import patch

import numpy as np
import supervision as sv
import torch
from PIL import Image
from gradio_client import handle_file
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


# Constants
FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = "<OPEN_VOCABULARY_DETECTION>"

SAM_CONFIG = "sam2_hiera_l.yaml"
SAM_CHECKPOINT = "checkpoints/sam2_hiera_large.pt"


def load_sam_image_model(
device: torch.device, config: str = SAM_CONFIG, checkpoint: str = SAM_CHECKPOINT
) -> SAM2ImagePredictor:
model = build_sam2(config, checkpoint, device=device)
return SAM2ImagePredictor(sam_model=model)


def run_sam_inference(
model: Any, image: Image.Image, detections: sv.Detections
) -> sv.Detections:
image_np = np.array(image.convert("RGB"))
model.set_image(image_np)
masks, scores, _ = model.predict(box=detections.xyxy, multimask_output=False)

# Ensure mask dimensions are correct
if len(masks.shape) == 4:
masks = np.squeeze(masks)

detections.mask = masks.astype(bool)
return detections


def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
"""Workaround for specific import issues."""
if not str(filename).endswith("/modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
if "flash_attn" in imports:
imports.remove("flash_attn")
return imports


def load_florence_model(
device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
) -> Tuple[Any, Any]:
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
model = (
AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
.to(device)
.eval()
)
processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
return model, processor


def run_florence_inference(
model: Any,
processor: Any,
device: torch.device,
image: Image.Image,
task: str,
text: str = "",
) -> Tuple[str, Dict]:
prompt = task + text
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
response = processor.post_process_generation(
generated_text, task=task, image_size=image.size
)
return generated_text, response


def main():
@@ -27,25 +116,41 @@ def main():
default="person",
help='Text prompt for masking (default: "person").',
)
parser.add_argument(
"--model",
type=str,
default="SkalskiP/florence-sam-masking",
help='Model name to use (default: "SkalskiP/florence-sam-masking").',
)
args = parser.parse_args()
if args.input_dir is None or args.output_dir is None:
import sys

sys.exit(1)

input_path = args.input_dir
output_path = args.output_dir
text_input = args.text_input
model_name = args.model

DEVICE = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Retrieve model
from huggingface_hub import hf_hub_download

print(f"Downloading SAM2 to {os.getcwd()}/checkpoints.")
hf_hub_download(
"SkalskiP/florence-sam-masking",
repo_type="space",
subfolder="checkpoints",
local_dir="./",
filename="sam2_hiera_large.pt",
)

# Load models
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# Create the output directory if it doesn't exist
os.makedirs(output_path, exist_ok=True)

# Initialize the Gradio client
client = Client(model_name)

# Get all files in the input directory
files = os.listdir(input_path)

Expand All @@ -65,16 +170,36 @@ def main():
continue
# Predict the mask
try:
mask_filename = client.predict(
image_input=handle_file(full_path),
text_input=text_input,
api_name="/process_image",
image_input = Image.open(full_path)
# cast to RGB
image_input = image_input.convert(
"RGB", dither=None, palette=Image.ADAPTIVE
)
_, result = run_florence_inference(
model=FLORENCE_MODEL,
processor=FLORENCE_PROCESSOR,
device=DEVICE,
image=image_input,
task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
text=text_input,
)
detections = sv.Detections.from_lmm(
lmm=sv.LMM.FLORENCE_2, result=result, resolution_wh=image_input.size
)
# Move the generated mask to the output directory
shutil.move(mask_filename, mask_path)
if len(detections) == 0:
print(f"No objects detected in {file}.")
continue
detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
# Combine masks if multiple detections
combined_mask = np.any(detections.mask, axis=0)
mask_image = Image.fromarray(combined_mask.astype("uint8") * 255)
mask_image.save(mask_path)
print(f"Saved mask to {mask_path}")
except Exception as e:
print(f"Failed to process {file}: {e}")
# Clean up
os.remove("checkpoints/sam2_hiera_large.pt")
os.rmdir("checkpoints")


if __name__ == "__main__":
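
To see how the rewritten script's pieces compose, here is a hedged single-image driver built from the functions defined above; the input path, output path, and prompt are illustrative, and it assumes the SAM2 checkpoint has already been downloaded as in main():

import numpy as np
import supervision as sv
import torch
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
florence_model, florence_processor = load_florence_model(device=device)
sam_image_model = load_sam_image_model(device=device)

image = Image.open("input/example.png").convert("RGB")  # hypothetical path
_, result = run_florence_inference(
    model=florence_model,
    processor=florence_processor,
    device=device,
    image=image,
    task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    text="person",
)
detections = sv.Detections.from_lmm(
    lmm=sv.LMM.FLORENCE_2, result=result, resolution_wh=image.size
)
if len(detections) > 0:
    detections = run_sam_inference(sam_image_model, image, detections)
    combined_mask = np.any(detections.mask, axis=0)  # merge per-object masks
    Image.fromarray(combined_mask.astype("uint8") * 255).save("example_mask.png")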