From 0baecb5f09da96bf5aaefbe11eb1739ee34e2bf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Mon, 15 Jul 2024 15:33:42 +0800 Subject: [PATCH] Add mobile_sam with controlnet_aux (#3000) * Add mobile_sam with controlnet_aux for CNXL_Union --- annotator/mobile_sam/__init__.py | 49 ++++++++++++++++++++++++++++++ requirements.txt | 1 + scripts/preprocessor/__init__.py | 3 +- scripts/preprocessor/mobile_sam.py | 25 +++++++++++++++ 4 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 annotator/mobile_sam/__init__.py create mode 100644 scripts/preprocessor/mobile_sam.py diff --git a/annotator/mobile_sam/__init__.py b/annotator/mobile_sam/__init__.py new file mode 100644 index 000000000..57b4124f8 --- /dev/null +++ b/annotator/mobile_sam/__init__.py @@ -0,0 +1,49 @@ +from __future__ import print_function + +import os +import numpy as np +from PIL import Image +from typing import Union + +from modules import devices +from annotator.util import load_model +from annotator.annotator_path import models_path + +from controlnet_aux import SamDetector +from controlnet_aux.segment_anything import sam_model_registry, SamAutomaticMaskGenerator + +class SamDetector_Aux(SamDetector): + + model_dir = os.path.join(models_path, "mobile_sam") + + def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam): + super().__init__(mask_generator) + self.device = devices.device + self.model = sam.to(self.device).eval() + + @classmethod + def from_pretrained(cls): + """ + Possible model_type : vit_h, vit_l, vit_b, vit_t + download weights from https://huggingface.co/dhkim2810/MobileSAM + """ + remote_url = os.environ.get( + "CONTROLNET_MOBILE_SAM_MODEL_URL", + "https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt", + ) + model_path = load_model( + "mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir + ) + + sam = sam_model_registry["vit_t"](checkpoint=model_path) + + cls.model = sam.to(devices.device).eval() + + mask_generator = SamAutomaticMaskGenerator(cls.model) + + return cls(mask_generator, sam) + + def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray: + self.model.to(self.device) + image = super().__call__(input_image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs) + return np.array(image).astype(np.uint8) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 10013834c..fef12cf3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ matplotlib facexlib timm<=0.9.5 pydantic<=1.10.17 +controlnet_aux \ No newline at end of file diff --git a/scripts/preprocessor/__init__.py b/scripts/preprocessor/__init__.py index 081af9779..75cbab682 100644 --- a/scripts/preprocessor/__init__.py +++ b/scripts/preprocessor/__init__.py @@ -5,4 +5,5 @@ from .ip_adapter_auto import * from .normal_dsine import * from .model_free_preprocessors import * -from .legacy.legacy_preprocessors import * \ No newline at end of file +from .legacy.legacy_preprocessors import * +from .mobile_sam import * \ No newline at end of file diff --git a/scripts/preprocessor/mobile_sam.py b/scripts/preprocessor/mobile_sam.py new file mode 100644 index 000000000..e394647c0 --- /dev/null +++ b/scripts/preprocessor/mobile_sam.py @@ -0,0 +1,25 @@ +from annotator.mobile_sam import SamDetector_Aux +from scripts.supported_preprocessor import Preprocessor + +class PreprocessorMobileSam(Preprocessor): + def __init__(self): + super().__init__(name="mobile_sam") + self.tags = ["Segmentation"] + self.model = None + + def __call__( + self, + input_image, + resolution, + slider_1=None, + slider_2=None, + slider_3=None, + **kwargs + ): + if self.model is None: + self.model = SamDetector_Aux.from_pretrained() + + result = self.model(input_image, detect_resolution=resolution, image_resolution=resolution, output_type="cv2") + return result + +Preprocessor.add_supported_preprocessor(PreprocessorMobileSam()) \ No newline at end of file