Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion hugging_face/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,58 @@
import warnings
warnings.filterwarnings("ignore")

# * Memory management functions for GPU optimization
def move_sam_to_cpu():
    """Offload the SAM network to CPU before matting runs.

    Calls ``.cpu()`` on ``model.samcontroler.sam_controler.model`` and then
    asks PyTorch to release its cached CUDA memory. Any failure is reported
    on stdout and swallowed, so the caller always continues.
    """
    print("Moving SAM to CPU...")
    try:
        sam_network = model.samcontroler.sam_controler.model
        sam_network.cpu()
        # * Release cached blocks so the freed memory is actually reusable
        torch.cuda.empty_cache()
        print("SAM moved to CPU successfully")
    except Exception as err:
        print(f"Error moving SAM to CPU: {err}")

def move_sam_to_gpu():
    """Restore the SAM network to its controller's device after matting.

    The target device is read from ``model.samcontroler.sam_controler.device``.
    The CUDA cache is emptied afterwards. Any failure is reported on stdout
    and swallowed, so the caller always continues.
    """
    print("Moving SAM to GPU...")
    try:
        sam_controller = model.samcontroler.sam_controler
        sam_controller.model.to(sam_controller.device)
        torch.cuda.empty_cache()
        print("SAM moved to GPU successfully")
    except Exception as err:
        print(f"Error moving SAM to GPU: {err}")

def move_matting_to_cpu():
    """Offload the MatAnyone network to CPU once matting is finished.

    Calls ``.cpu()`` on the global ``matanyone_model`` and then empties the
    CUDA cache. Any failure is reported on stdout and swallowed, so the
    caller always continues.
    """
    print("Moving MatAnyone to CPU...")
    try:
        matting_network = matanyone_model
        matting_network.cpu()
        # * Release cached blocks so the freed memory is actually reusable
        torch.cuda.empty_cache()
        print("MatAnyone moved to CPU successfully")
    except Exception as err:
        print(f"Error moving MatAnyone to CPU: {err}")

def move_matting_to_gpu():
    """Place the MatAnyone network on the configured device for inference.

    The target is ``args.device``, so "GPU" in the log messages means
    whatever device was configured there. Any failure is reported on stdout
    and swallowed, so the caller always continues.
    """
    print("Moving MatAnyone to GPU...")
    try:
        target_device = args.device
        matanyone_model.to(target_device)
        print("MatAnyone moved to GPU successfully")
    except Exception as err:
        print(f"Error moving MatAnyone to GPU: {err}")

def play_completion_sound():
    """Signal on stdout that matting has finished.

    Prints the ASCII bell character followed by a completion message. The
    notification is best-effort and never raises for an output problem:

    * If stdout cannot encode the emoji in the message (for example an
      ASCII or cp1252 console), a plain-ASCII message is printed instead.
    * If stdout is closed or broken, the notification is skipped.
    """
    try:
        print("\a")  # * Terminal bell character
        print("🔔 Matting completed!")
    except UnicodeEncodeError:
        # Retrying the same emoji string would fail the same way, so fall
        # back to text every encoding can represent.
        print("Matting completed!")
    except (OSError, ValueError):
        # stdout is unusable (closed or broken pipe); a missing beep must
        # not discard the matting result the caller is about to return.
        pass

def parse_augment():
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default=None)
Expand Down Expand Up @@ -269,6 +321,10 @@ def show_mask(video_state, interactive_state, mask_dropdown):

# image matting
def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
# * Move SAM to CPU and MatAnyone to GPU for matting
move_sam_to_cpu()
move_matting_to_gpu()

matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
if interactive_state["track_end_number"]:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
Expand All @@ -294,10 +350,19 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
foreground_output = Image.fromarray(foreground[-1])
alpha_output = Image.fromarray(alpha[-1][:,:,0])

# * Move MatAnyone to CPU and SAM back to GPU after matting
move_matting_to_cpu()
move_sam_to_gpu()
play_completion_sound()

return foreground_output, alpha_output

# video matting
def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
# * Move SAM to CPU and MatAnyone to GPU for matting
move_sam_to_cpu()
move_matting_to_gpu()

matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
if interactive_state["track_end_number"]:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
Expand Down Expand Up @@ -327,6 +392,11 @@ def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video

# * Move MatAnyone to CPU and SAM back to GPU after matting
move_matting_to_cpu()
move_sam_to_gpu()
play_completion_sound()

return foreground_output, alpha_output


Expand Down Expand Up @@ -421,7 +491,8 @@ def restart():
pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
ckpt_path = load_file_from_url(pretrain_model_url, checkpoint_folder)
matanyone_model = get_matanyone_model(ckpt_path, args.device)
matanyone_model = matanyone_model.to(args.device).eval()
# * Start with MatAnyone on CPU to save GPU memory for SAM
matanyone_model = matanyone_model.cpu().eval()
# matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)

# download test samples
Expand Down