
[feature] text_to_sam_clip (PaddlePaddle#3187)
Sunting78 authored Apr 17, 2023
1 parent 2319feb commit fdbb025
Showing 10 changed files with 798 additions and 36 deletions.
55 changes: 30 additions & 25 deletions contrib/SegmentAnything/README.md
@@ -1,25 +1,24 @@
# Segment Anything with PaddleSeg

## Reference

> Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollár, Ross Girshick. [Segment Anything](https://ai.facebook.com/research/publications/segment-anything/).

## Contents
1. Overview
2. Performance
3. Try it by yourself with one line of code
4. Reference


## <img src="https://user-images.githubusercontent.com/34859558/190043857-bfbdaf8b-d2dc-4fff-81c7-e0aac50851f9.png" width="25"/> Overview

We implement Segment Anything with the PaddlePaddle framework. **Segment Anything Model (SAM)** is a new task, model, and dataset for image segmentation. It can produce high-quality object masks from different types of prompts, including points, boxes, masks and text. Further, SAM can generate masks for all objects in a whole image. It was trained on the largest segmentation [dataset](https://segment-anything.com/dataset/index.html) to date, with over 1 billion masks on 11M licensed and privacy-respecting images. SAM shows impressive zero-shot performance on a variety of tasks, often competitive with or even superior to prior fully supervised results.
We implement Segment Anything with the PaddlePaddle framework. **Segment Anything Model (SAM)** is a new task, model, and dataset for image segmentation. It was trained on the largest segmentation [dataset](https://segment-anything.com/dataset/index.html) to date, with over 1 billion masks on 11M licensed and privacy-respecting images. SAM can produce high-quality object masks from different types of prompts, including points, boxes, masks and text, and it shows impressive zero-shot performance on a variety of tasks, often competitive with or even superior to prior fully supervised results. However, the text-prompt variant of SAM has not been released yet. Therefore, we combine **SAM** and **CLIP**, using CLIP to score the similarity between the output masks and the text prompt, so that you can segment anything with a **text prompt**. In addition, we also provide SAM inference that generates masks for all objects in the whole image.
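
The text-prompt pipeline in this commit (see `scripts/text_to_sam_clip.py` below) follows a simple recipe: SAM generates masks for the whole image, each mask is cropped out, CLIP scores the crops against the text prompt, and masks above a small score threshold are kept. The sketch below is illustrative only; `model`, `transform`, and `tokenize` are assumed to be the CLIP helpers built by that script.

```python
# Illustrative sketch of the SAM + CLIP matching step (see text_to_sam_clip.py);
# `model`, `transform`, and `tokenize` are assumed to come from the CLIP helpers
# built in that script.
import paddle
import paddle.nn.functional as F


def match_masks_to_text(cropped_objects, text_query, model, transform, tokenize):
    # Encode the cropped mask images and the text prompt with CLIP.
    images = paddle.stack([transform(img) for img in cropped_objects])
    image_features = model.encode_image(images)
    text_features = model.encode_text(tokenize([text_query]))
    # Normalize the features and softmax the similarities over all masks.
    image_features /= image_features.norm(axis=-1, keepdim=True)
    text_features /= text_features.norm(axis=-1, keepdim=True)
    scores = F.softmax((100. * image_features @ text_features.T)[:, 0], axis=0)
    # Keep the masks whose normalized score clears the threshold used by the script.
    return [i for i, s in enumerate(scores) if float(s) >= 0.05]
```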


We provide pretrained model parameters in PaddlePaddle format, including [vit_b](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams), [vit_l](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams) and [vit_h](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams).
We provide pretrained model parameters in PaddlePaddle format, including [vit_b](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams), [vit_l](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams) and [vit_h](https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams). For text prompts, we also provide the [CLIP_ViT_B](https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams) model parameters in PaddlePaddle format.
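
For reference, `scripts/text_to_sam_clip.py` in this commit loads these checkpoints roughly as follows (a minimal sketch; the URLs are the ones linked above):

```python
# Minimal loading sketch, mirroring scripts/text_to_sam_clip.py in this commit.
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.modeling.clip_paddle import build_clip_model

sam = sam_model_registry["vit_b"](
    checkpoint="https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams")
mask_generator = SamAutomaticMaskGenerator(sam)

clip_model, clip_transform = build_clip_model(
    "https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams")
```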

## <img src="https://user-images.githubusercontent.com/34859558/190044217-8f6befc2-7f20-473d-b356-148e06265205.png" width="25"/> Performance

<div align="center">
<img src="https://github.com/Sunting78/images/blob/master/sam_new.gif" width="1000" />
<img src="https://user-images.githubusercontent.com/18344247/232466911-f8d1c016-2eb2-46aa-94e2-3ec435f38502.gif" width="1000" />
</div>


@@ -33,44 +32,51 @@ We provide the pretrained model parameters of PaddlePaddle format, including [vi
git clone https://github.com/PaddlePaddle/PaddleSeg.git
cd PaddleSeg
pip install -r requirements.txt
pip install ftfy regex
cd contrib/SegmentAnything/
```
* Download the example image to ```contrib/SegmentAnything/examples```, and the file structure is as follows:
* Download the example image to ```contrib/SegmentAnything/examples``` and the vocab file to ```contrib/SegmentAnything/```:
```bash
wget https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png
wget https://bj.bcebos.com/paddleseg/dygraph/bpe_vocab_16e6/bpe_simple_vocab_16e6.txt.gz
```
Then, the file structure is as follows:

```
PaddleSeg/contrib
├── SegmentAnything
│   ├── examples
│   │   └── cityscapes_demo.png
│   ├── segment_anything
│   └── scripts
│   ├── scripts
│   └── bpe_simple_vocab_16e6.txt.gz

```
### 2. Segment Anything on the webpage

### 2. Segment the whole image on the webpage
In this step, we start a Gradio service with the following script on the local machine, and you can try out our project with your own images.
Based on this service, you can experience the ability to **segment the whole image** and to **segment objects based on text prompts**.

1. Run the following script:
```bash
python scripts/amg_paddle.py --model-type [vit_l/vit_b/vit_h] # default is vit_h
python scripts/text_to_sam_clip.py --model-type [vit_l/vit_b/vit_h] # default is vit_h
```
Note:
* There are three model options: vit_b, vit_l and vit_h, representing vit_base, vit_large and vit_huge. A larger model is more accurate but also slower. You can choose the model size based on your device.
* Our tests show that vit_h needs 16 GB of GPU memory and takes around 10 s to infer one image on a V100.

2. Open the webpage on your localhost: ```http://0.0.0.0:8017```
* There are three SAM model options: `vit_b`, `vit_l` and `vit_h`, representing vit_base, vit_large and vit_huge. A larger model is more accurate but slower. You can choose a suitable model size based on your device.
* We use the `CLIP ViT-B` model to extract text and image features.
* `SAM vit_h` needs 16 GB of GPU memory and takes around 10 s to infer one image on a V100.

2. Open the webpage on your localhost: ```http://0.0.0.0:8078```
3. Try it out by clearing and uploading your test image! Our example looks like this:

<div align="center">
<img src="https://user-images.githubusercontent.com/34859558/230873989-9597527e-bef6-47ce-988b-977198794d75.jpg" width = "1000" />
<img src="https://user-images.githubusercontent.com/18344247/232427677-a7f913df-4abf-46ce-be2c-e37cbd495105.png" width = "1000" />
</div>

### 3. Segment the object with prompts
You can run the following commands to produce masks from different types of prompts, including points, boxes, and masks, as follows:

### 3. Segment the object with point or box prompts

You can run the following commands to produce masks from different types of prompts, including points and boxes, as follows (a programmatic sketch is given after the examples):


1. Box prompt
@@ -84,10 +90,9 @@ python scripts/promt_predict.py --input_path xxx.png --box_prompt 1050 370 1500
python scripts/promt_predict.py --input_path xxx.png --point_prompt 1200 450 --model-type [vit_l/vit_b/vit_h] # default is vit_h
```

3. Mask prompt
```bash
python scripts/promt_predict.py --input_path xxx.png --mask_prompt xxx.png --model-type [vit_l/vit_b/vit_h] # default is vit_h
```

Note:
* mask_prompt is the path of a binary image.
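
For reference, the predictor used by `scripts/promt_predict.py` can also be driven programmatically. The sketch below is an assumption-laden illustration: it assumes the contrib `segment_anything` package exports a `SamPredictor` with `set_image`/`predict` methods like the upstream SAM repository, and that `predict` returns masks, scores, and logits.

```python
# Hedged sketch of calling the predictor directly; `SamPredictor`, `set_image`,
# and the return values of `predict` are assumed to follow the upstream SAM API.
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor  # SamPredictor assumed

sam = sam_model_registry["vit_b"](
    checkpoint="https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams")
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("examples/cityscapes_demo.png"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# A single foreground point prompt; a box prompt would instead pass
# box=np.array([[1050, 370], [1500, 700]]) as in scripts/promt_predict.py.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[1200, 450]]),
    point_labels=np.array([1]),
    box=None,
    multimask_output=True, )
```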
## Reference

> Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollár, Ross Girshick. [Segment Anything](https://ai.facebook.com/research/publications/segment-anything/).

> Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. Learning Transferable Visual Models From Natural Language Supervision. Proceedings of the 38th International Conference on Machine Learning, PMLR 139:8748-8763, 2021. [CLIP](https://github.com/openai/CLIP)
Binary file added contrib/SegmentAnything/examples/dog.jpg
Binary file added contrib/SegmentAnything/examples/zixingche.jpeg
11 changes: 2 additions & 9 deletions contrib/SegmentAnything/scripts/promt_predict.py
@@ -39,7 +39,7 @@

def get_args():
parser = argparse.ArgumentParser(
description='Segment image with point prompt, box or mask')
description='Segment image with point prompt or box')
# Parameters
parser.add_argument(
'--input_path', type=str, required=True, help='The path of the input image.')
@@ -61,8 +61,6 @@ def get_args():
nargs='+',
default=None,
help='box prompt in xyxy format.')
parser.add_argument(
'--mask_prompt', type=str, default=None, help='The path of mask.')
parser.add_argument(
'--output_path',
type=str,
@@ -88,18 +86,14 @@ def main(args):
paddle.set_device("cpu")
input_path = args.input_path
output_path = args.output_path
point, box, mask_path = args.point_prompt, args.box_prompt, args.mask_prompt
point, box = args.point_prompt, args.box_prompt
if point is not None:
point = np.array([point])
input_label = np.array([1])
else:
input_label = None
if box is not None:
box = np.array([[box[0], box[1]], [box[2], box[3]]])
if mask_path is not None:
mask = cv2.imread(mask_path, -1)
else:
mask = None

image = cv2.imread(input_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -112,7 +106,6 @@ def main(args):
point_coords=point,
point_labels=input_label,
box=box,
mask_input=mask,
multimask_output=True, )

plt.figure(figsize=(10, 10))
239 changes: 239 additions & 0 deletions contrib/SegmentAnything/scripts/text_to_sam_clip.py
@@ -0,0 +1,239 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import cv2
import time
import sys
import argparse
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))

import paddle
import paddle.nn.functional as F
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.modeling.clip_paddle import build_clip_model, _transform
from segment_anything.utils.sample_tokenizer import tokenize
from paddleseg.utils.visualize import get_pseudo_color_map, get_color_map_list

ID_PHOTO_IMAGE_DEMO = "./examples/cityscapes_demo.png"
CACHE_DIR = ".temp"
model_link = {
    'vit_h':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_h/model.pdparams",
    'vit_l':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
    'vit_b':
    "https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
    'clip_b_32':
    "https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams"
}

parser = argparse.ArgumentParser(description=(
    "Runs SAM automatic mask generation on an input image, scores the masks "
    "against a text prompt with CLIP, and serves the results through a local "
    "Gradio demo."))

parser.add_argument(
    "--model-type",
    type=str,
    default="vit_h",  # used when the flag is omitted, as documented in the README
    help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']", )


def download(img):
    # Save the result image into the cache directory under a timestamp-based
    # name so that Gradio can serve it as a downloadable file.
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR)
    while True:
        name = str(int(time.time()))
        tmp_name = os.path.join(CACHE_DIR, name + '.jpg')
        if not os.path.exists(tmp_name):
            break
        else:
            time.sleep(1)
    img.save(tmp_name, 'png')
    return tmp_name


def segment_image(image, segment_mask):
    # Paste the masked region of the image onto a gray canvas so that only
    # the segmented object remains visible.
    image_array = np.array(image)
    gray_image = Image.new("RGB", image.size, (128, 128, 128))
    segmented_image_array = np.zeros_like(image_array)
    segmented_image_array[segment_mask] = image_array[segment_mask]
    segmented_image = Image.fromarray(segmented_image_array)
    transparency = np.zeros_like(segment_mask, dtype=np.uint8)
    transparency[segment_mask] = 255
    transparency_image = Image.fromarray(transparency, mode='L')
    gray_image.paste(segmented_image, mask=transparency_image)
    return gray_image


def image_text_match(cropped_objects, text_query):
    # Encode the cropped mask images and the text prompt with CLIP, normalize
    # the features, and softmax the similarities over all masks.
    transformed_images = [transform(image) for image in cropped_objects]
    tokenized_text = tokenize([text_query])
    batch_images = paddle.stack(transformed_images)
    image_features = model.encode_image(batch_images)
    print("encode_image done!")
    text_features = model.encode_text(tokenized_text)
    print("encode_text done!")
    image_features /= image_features.norm(axis=-1, keepdim=True)
    text_features /= text_features.norm(axis=-1, keepdim=True)
    probs = 100. * image_features @ text_features.T
    return F.softmax(probs[:, 0], axis=0)


def masks2pseudomap(masks):
    # Merge all binary masks into a single label map and convert it to a
    # pseudo-color map for display.
    result = np.ones(masks[0]["segmentation"].shape, dtype=np.uint8) * 255
    for i, mask_data in enumerate(masks):
        result[mask_data["segmentation"] == 1] = i + 1
    pred_result = result
    result = get_pseudo_color_map(result)
    return pred_result, result


def visualize(image, result, color_map, weight=0.6):
    """
    Convert the predicted result to a color image and blend it with the input image.
    Args:
        image (np.ndarray): The origin image.
        result (np.ndarray): The predicted label map of the image.
        color_map (list): The color used to save the prediction results.
        weight (float): The image weight of the visual image, and the result weight is (1 - weight). Default: 0.6
    Returns:
        vis_result (np.ndarray): The visualized result.
    """
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    color_map = np.array(color_map).astype("uint8")
    # Use OpenCV LUT for color mapping
    c1 = cv2.LUT(result, color_map[:, 0])
    c2 = cv2.LUT(result, color_map[:, 1])
    c3 = cv2.LUT(result, color_map[:, 2])
    pseudo_img = np.dstack((c3, c2, c1))

    vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
    return vis_result


def get_id_photo_output(image, text):
    """
    Segment the image with SAM, score every mask against the text prompt with
    CLIP, and highlight the matching masks on the original image.
    Args:
        image (np.ndarray): The input image array.
        text (str): The text prompt.
    """
    image_ori = image.copy()
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)
    pred_result, pseudo_map = masks2pseudomap(masks)  # PIL Image
    added_pseudo_map = visualize(
        image, pred_result, color_map=get_color_map_list(256))
    cropped_objects = []
    image_pil = Image.fromarray(image)
    for mask in masks:
        bbox = [
            mask["bbox"][0], mask["bbox"][1], mask["bbox"][0] + mask["bbox"][2],
            mask["bbox"][1] + mask["bbox"][3]
        ]
        cropped_objects.append(
            segment_image(image_pil, mask["segmentation"]).crop(bbox))

    scores = image_text_match(cropped_objects, str(text))
    text_matching_masks = []
    for idx, score in enumerate(scores):
        if score < 0.05:
            continue
        text_matching_mask = Image.fromarray(
            masks[idx]["segmentation"].astype('uint8') * 255)
        text_matching_masks.append(text_matching_mask)

    image_pil_ori = Image.fromarray(image_ori)
    alpha_image = Image.new('RGBA', image_pil_ori.size, (0, 0, 0, 0))
    alpha_color = (255, 0, 0, 180)

    draw = ImageDraw.Draw(alpha_image)
    for text_matching_mask in text_matching_masks:
        draw.bitmap((0, 0), text_matching_mask, fill=alpha_color)

    result_image = Image.alpha_composite(
        image_pil_ori.convert('RGBA'), alpha_image)
    res_download = download(result_image)
    return result_image, added_pseudo_map, res_download


def gradio_display():
    import gradio as gr
    examples_sam = [["./examples/cityscapes_demo.png", "a photo of car"],
                    ["examples/dog.jpg", "dog"],
                    ["examples/zixingche.jpeg", "kid"]]

    demo_mask_sam = gr.Interface(
        fn=get_id_photo_output,
        inputs=[
            gr.Image(
                value=ID_PHOTO_IMAGE_DEMO,
                label="Input image").style(height=400), gr.inputs.Textbox(
                    lines=3,
                    placeholder=None,
                    default="a photo of car",
                    label='🔥 Input text prompt 🔥',
                    optional=False)
        ],
        outputs=[
            gr.Image(
                label="Output based on text",
                interactive=False).style(height=300), gr.Image(
                    label="Output mask", interactive=False).style(height=300)
        ],
        examples=examples_sam,
        description="<p> \
            <strong>SAM+CLIP: Text prompt for segmentation. </strong> <br>\
            Choose an example below; Or, upload by yourself: <br>\
            1. Upload images to be tested to 'input image'. 2. Input a text prompt to 'input text prompt' and click 'submit'</strong>. <br>\
            </p>",
        cache_examples=False,
        allow_flagging="never", )

    demo = gr.TabbedInterface(
        [demo_mask_sam, ], ['SAM+CLIP(Text to Segment)'],
        title=" 🔥 Text to Segment Anything with PaddleSeg 🔥")
    demo.launch(
        server_name="0.0.0.0", enable_queue=False, server_port=8078, share=True)


args = parser.parse_args()
print("Loading model...")

if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")
else:
    paddle.set_device("cpu")

sam = sam_model_registry[args.model_type](
    checkpoint=model_link[args.model_type])
mask_generator = SamAutomaticMaskGenerator(sam)

model, transform = build_clip_model(model_link["clip_b_32"])
gradio_display()
