
Commit

Merge pull request #1080 from bghira/main
merge masked loss + reg image fixes
bghira authored Oct 19, 2024
2 parents 8bf644f + 68c69d3 commit 056d1b9
Showing 6 changed files with 287 additions and 55 deletions.
13 changes: 7 additions & 6 deletions helpers/data_backend/factory.py
@@ -559,15 +559,19 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
or backend["id"] in StateTracker.get_data_backends()
):
raise ValueError("Each dataset needs a unique 'id' field.")
info_log(f"Configuring data backend: {backend['id']}")
conditioning_type = backend.get("conditioning_type")
if backend.get("dataset_type") == "conditioning" or conditioning_type is not None:
backend['dataset_type'] = 'conditioning'
resolution_type = backend.get("resolution_type", args.resolution_type)
if resolution_type == "pixel_area":
pixel_edge_length = backend.get("resolution")
pixel_edge_length = backend.get("resolution", int(args.resolution))
if pixel_edge_length is None or (
type(pixel_edge_length) is not int
or not str(pixel_edge_length).isdigit()
):
raise ValueError(
f"Resolution type 'pixel_area' requires a 'resolution' field to be set in the backend config using an integer in the format: 1024"
f"Resolution type 'pixel_area' requires a 'resolution' field to be set in the backend config using an integer in the format: 1024, but {pixel_edge_length} was given"
)
# we'll convert pixel_area to area
backend["resolution_type"] = "area"
@@ -596,7 +600,6 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
backend["minimum_image_size"] * backend["minimum_image_size"]
) / 1_000_000

info_log(f"Configuring data backend: {backend['id']}")
# Retrieve some config file overrides for commandline arguments, eg. cropping
init_backend = init_backend_config(backend, args, accelerator)
StateTracker.set_data_backend_config(
@@ -748,9 +751,6 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
repeats=init_backend["config"].get("repeats", 0),
**metadata_backend_args,
)
conditioning_type = None
if backend.get("dataset_type") == "conditioning":
conditioning_type = backend.get("conditioning_type", "controlnet")

if (
"aspect" not in args.skip_file_discovery
@@ -1066,6 +1066,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize

if (
not args.vae_cache_ondemand
and "vaecache" in init_backend
and "vae" not in args.skip_file_discovery
and "vae" not in backend.get("skip_file_discovery", "")
and "deepfloyd" not in StateTracker.get_args().model_type
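The pixel_area handling in this hunk validates the edge length and then converts it to a megapixel area. A minimal sketch of that conversion, assuming a bare backend dict and squaring the edge length the same way the minimum_image_size conversion above does (names and values here are illustrative, not the factory's actual helpers):

# Illustrative sketch: convert a pixel_area edge length into a megapixel area.
backend = {"resolution_type": "pixel_area", "resolution": 1024}

if backend["resolution_type"] == "pixel_area":
    edge = backend["resolution"]
    if not isinstance(edge, int):
        raise ValueError(
            f"Resolution type 'pixel_area' requires an integer 'resolution', but {edge} was given"
        )
    backend["resolution_type"] = "area"
    # assumed conversion: edge^2 pixels expressed in megapixels
    backend["resolution"] = (edge * edge) / 1_000_000

print(backend)  # {'resolution_type': 'area', 'resolution': 1.048576}
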
2 changes: 1 addition & 1 deletion helpers/image_manipulation/training_sample.py
@@ -80,7 +80,7 @@ def __init__(
self.resolution = self.data_backend_config.get("resolution")
self.resolution_type = self.data_backend_config.get("resolution_type")
self.target_size_calculator = resize_helpers.get(self.resolution_type)
if self.target_size_calculator is None:
if self.target_size_calculator is None and conditioning_type not in ["mask", "controlnet"]:
raise ValueError(f"Unknown resolution type: {self.resolution_type}")
self._set_resolution()
self.target_downsample_size = self.data_backend_config.get(
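The one-line change above lets conditioning samples pass through when no size calculator is registered for their resolution type. An illustrative-only sketch of that guard, assuming a resize_helpers mapping passed in from elsewhere (the helper name below is hypothetical, not part of the repository):

def pick_target_size_calculator(resize_helpers, resolution_type, conditioning_type):
    # resize_helpers maps resolution_type -> sizing function; not reproduced here.
    calculator = resize_helpers.get(resolution_type)
    # Mask/controlnet conditioning samples are allowed to fall through without
    # a calculator; only regular samples raise on an unknown resolution type.
    if calculator is None and conditioning_type not in ["mask", "controlnet"]:
        raise ValueError(f"Unknown resolution type: {resolution_type}")
    return calculator
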
77 changes: 47 additions & 30 deletions helpers/training/trainer.py
@@ -2412,24 +2412,30 @@ def train(self):
# x-prediction requires that we now subtract the noise residual from the prediction to get the target sample.
if (
hasattr(self.noise_scheduler, "config")
and hasattr(self.noise_scheduler.config, "prediction_type")
and self.noise_scheduler.config.prediction_type == "sample"
and hasattr(
self.noise_scheduler.config, "prediction_type"
)
and self.noise_scheduler.config.prediction_type
== "sample"
):
model_pred = model_pred - noise

parent_loss = None

# Compute the per-pixel loss without reducing over spatial dimensions
if self.config.flow_matching:
loss = torch.mean(
((model_pred.float() - target.float()) ** 2).reshape(
target.shape[0], -1
),
1,
)
elif self.config.snr_gamma is None or self.config.snr_gamma == 0:
# For flow matching, compute the per-pixel squared differences
loss = (
model_pred.float() - target.float()
) ** 2 # Shape: (batch_size, C, H, W)
elif (
self.config.snr_gamma is None
or self.config.snr_gamma == 0
):
training_logger.debug("Calculating loss")
loss = self.config.snr_weight * F.mse_loss(
model_pred.float(), target.float(), reduction="none"
)
) # Shape: (batch_size, C, H, W)
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
@@ -2442,7 +2448,8 @@ def train(self):
== "v_prediction"
or (
self.config.flow_matching
and self.config.flow_matching_loss == "diffusion"
and self.config.flow_matching_loss
== "diffusion"
)
):
snr_divisor = snr + 1
@@ -2454,52 +2461,62 @@ def train(self):
torch.stack(
[
snr,
self.config.snr_gamma * torch.ones_like(timesteps),
self.config.snr_gamma
* torch.ones_like(timesteps),
],
dim=1,
).min(dim=1)[0]
/ snr_divisor
)
) # Shape: (batch_size,)

# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
# Compute the per-pixel MSE loss without reduction
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="none"
)
) # Shape: (batch_size, C, H, W)

# Reshape mse_loss_weights for broadcasting and apply to loss
mse_loss_weights = mse_loss_weights.view(
-1, 1, 1, 1
) # Shape: (batch_size, 1, 1, 1)
loss = (
loss.mean(dim=list(range(1, len(loss.shape))))
* mse_loss_weights
)
loss * mse_loss_weights
) # Shape: (batch_size, C, H, W)

# Mask the loss using any conditioning data
conditioning_type = batch.get("conditioning_type")
if conditioning_type == "mask":
# adapted from: https://github.com/kohya-ss/sd-scripts/blob/main/library/custom_train_functions.py#L482
# Adapted from:
# https://github.com/kohya-ss/sd-scripts/blob/main/library/custom_train_functions.py#L482
mask_image = (
batch["conditioning_pixel_values"]
.to(dtype=loss.dtype, device=loss.device)[:, 0]
.unsqueeze(1)
)
) # Shape: (batch_size, 1, H', W')
mask_image = torch.nn.functional.interpolate(
mask_image, size=loss.shape[2:], mode="area"
)
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
) # Resize to match loss spatial dimensions
mask_image = mask_image / 2 + 0.5 # Normalize to [0,1]
loss = loss * mask_image # Element-wise multiplication

# Reduce the loss by averaging over channels and spatial dimensions
loss = loss.mean(
dim=list(range(1, len(loss.shape)))
) # Shape: (batch_size,)

# Further reduce the loss by averaging over the batch dimension
loss = loss.mean() # Scalar value

# reduce loss now
loss = loss.mean()
if is_regularisation_data:
parent_loss = loss

# Gather the losses across all processes for logging (if we use distributed training).
# Gather the losses across all processes for logging (if using distributed training)
avg_loss = self.accelerator.gather(
loss.repeat(self.config.train_batch_size)
).mean()
self.train_loss += (
avg_loss.item() / self.config.gradient_accumulation_steps
avg_loss.item()
/ self.config.gradient_accumulation_steps
)

# Backpropagate
grad_norm = None
if not self.config.disable_accelerator:
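The trainer changes above keep the loss per-pixel so a conditioning mask can be applied before any reduction, then reduce per sample and finally over the batch. A condensed, runnable sketch of that flow with dummy tensors, assuming snr_gamma is set and using the min-SNR weighting from arXiv 2303.09556 (the SNR values below are stand-ins, not computed from a scheduler):

import torch
import torch.nn.functional as F

batch, c, h, w = 4, 4, 64, 64
model_pred = torch.randn(batch, c, h, w)
target = torch.randn(batch, c, h, w)
snr = torch.rand(batch) * 20 + 1e-3   # per-sample SNR, stand-in values
snr_gamma = 5.0

# Per-pixel loss, no reduction yet: (batch, C, H, W)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

# Min-SNR weights per sample: min(SNR, gamma) / SNR (the divisor becomes SNR + 1
# for v-prediction or diffusion-style flow matching, as in the hunk above)
mse_loss_weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
loss = loss * mse_loss_weights.view(-1, 1, 1, 1)

# Optional mask from a conditioning dataset, resized to the loss resolution
mask = torch.rand(batch, 1, 512, 512) * 2 - 1             # stand-in mask in [-1, 1]
mask = F.interpolate(mask, size=loss.shape[2:], mode="area")
mask = mask / 2 + 0.5                                      # rescale to [0, 1]
loss = loss * mask

# Reduce: per-sample mean over C, H, W, then mean over the batch
loss = loss.mean(dim=list(range(1, len(loss.shape)))).mean()
print(loss.item())
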
161 changes: 143 additions & 18 deletions toolkit/datasets/masked_loss/generate_dataset_masks.py
@@ -1,7 +1,96 @@
import argparse
import os
import shutil
from gradio_client import Client, handle_file

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from typing import Union, Any, Tuple, Dict
from unittest.mock import patch

import numpy as np
import supervision as sv
import torch
from PIL import Image
from gradio_client import handle_file
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


# Constants
FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = "<OPEN_VOCABULARY_DETECTION>"

SAM_CONFIG = "sam2_hiera_l.yaml"
SAM_CHECKPOINT = "checkpoints/sam2_hiera_large.pt"


def load_sam_image_model(
device: torch.device, config: str = SAM_CONFIG, checkpoint: str = SAM_CHECKPOINT
) -> SAM2ImagePredictor:
model = build_sam2(config, checkpoint, device=device)
return SAM2ImagePredictor(sam_model=model)


def run_sam_inference(
model: Any, image: Image.Image, detections: sv.Detections
) -> sv.Detections:
image_np = np.array(image.convert("RGB"))
model.set_image(image_np)
masks, scores, _ = model.predict(box=detections.xyxy, multimask_output=False)

# Ensure mask dimensions are correct
if len(masks.shape) == 4:
masks = np.squeeze(masks)

detections.mask = masks.astype(bool)
return detections


def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
"""Workaround for specific import issues."""
if not str(filename).endswith("/modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
if "flash_attn" in imports:
imports.remove("flash_attn")
return imports


def load_florence_model(
device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
) -> Tuple[Any, Any]:
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
model = (
AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
.to(device)
.eval()
)
processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
return model, processor


def run_florence_inference(
model: Any,
processor: Any,
device: torch.device,
image: Image.Image,
task: str,
text: str = "",
) -> Tuple[str, Dict]:
prompt = task + text
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
response = processor.post_process_generation(
generated_text, task=task, image_size=image.size
)
return generated_text, response


def main():
@@ -27,25 +116,41 @@ def main():
default="person",
help='Text prompt for masking (default: "person").',
)
parser.add_argument(
"--model",
type=str,
default="SkalskiP/florence-sam-masking",
help='Model name to use (default: "SkalskiP/florence-sam-masking").',
)
args = parser.parse_args()
if args.input_dir is None or args.output_dir is None:
import sys

sys.exit(1)

input_path = args.input_dir
output_path = args.output_dir
text_input = args.text_input
model_name = args.model

DEVICE = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Retrieve model
from huggingface_hub import hf_hub_download

print(f"Downloading SAM2 to {os.getcwd()}/checkpoints.")
hf_hub_download(
"SkalskiP/florence-sam-masking",
repo_type="space",
subfolder="checkpoints",
local_dir="./",
filename="sam2_hiera_large.pt",
)

# Load models
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# Create the output directory if it doesn't exist
os.makedirs(output_path, exist_ok=True)

# Initialize the Gradio client
client = Client(model_name)

# Get all files in the input directory
files = os.listdir(input_path)

@@ -65,16 +170,36 @@ def main():
continue
# Predict the mask
try:
mask_filename = client.predict(
image_input=handle_file(full_path),
text_input=text_input,
api_name="/process_image",
image_input = Image.open(full_path)
# cast to RGB
image_input = image_input.convert(
"RGB", dither=None, palette=Image.ADAPTIVE
)
_, result = run_florence_inference(
model=FLORENCE_MODEL,
processor=FLORENCE_PROCESSOR,
device=DEVICE,
image=image_input,
task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
text=text_input,
)
detections = sv.Detections.from_lmm(
lmm=sv.LMM.FLORENCE_2, result=result, resolution_wh=image_input.size
)
# Move the generated mask to the output directory
shutil.move(mask_filename, mask_path)
if len(detections) == 0:
print(f"No objects detected in {file}.")
continue
detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
# Combine masks if multiple detections
combined_mask = np.any(detections.mask, axis=0)
mask_image = Image.fromarray(combined_mask.astype("uint8") * 255)
mask_image.save(mask_path)
print(f"Saved mask to {mask_path}")
except Exception as e:
print(f"Failed to process {file}: {e}")
# Clean up
os.remove("checkpoints/sam2_hiera_large.pt")
os.rmdir("checkpoints")


if __name__ == "__main__":
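Instead of calling a remote Gradio space, the script now merges all SAM2 instance masks for the prompt into one binary mask before saving. A minimal sketch of that merge step, with dummy arrays standing in for detections.mask (shapes and values are illustrative):

import numpy as np
from PIL import Image

h, w = 256, 256
instance_masks = np.zeros((2, h, w), dtype=bool)   # two dummy detections
instance_masks[0, 40:120, 40:120] = True
instance_masks[1, 100:200, 150:220] = True

# Union of all instances -> single foreground mask
combined_mask = np.any(instance_masks, axis=0)

# Save as a 0/255 grayscale image, like the masks written by the loop above
mask_image = Image.fromarray(combined_mask.astype("uint8") * 255)
mask_image.save("example_mask.png")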