Skip to content

Commit

Permalink
SAM2 AMG cli and other QoL improvements (#1336)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Nov 23, 2024
1 parent 51c87b6 commit b2e42ff
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 69 deletions.
10 changes: 5 additions & 5 deletions examples/sam2_amg_server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output
Start the server

```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast
```

Collect the RLEs (run-length encoded masks)
Expand Down Expand Up @@ -58,7 +58,7 @@ Make sure you've installed https://github.com/facebookresearch/sam2

Start server
```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --baseline
python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --baseline
```

Generate and save RLEs (one JSON object per line via `-w "\n"`)
Expand All @@ -73,7 +73,7 @@ sys 0m4.137s
### 3. Start server with torchao variant of SAM2
Start server
```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname>
python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname>
```

Generate and save RLEs (one JSON object per line via `-w "\n"`)
Expand All @@ -88,7 +88,7 @@ sys 0m4.350s
### 4. Start server with torchao variant of SAM2 and `--fast` optimizations
Start server
```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast
```

Generate and save RLEs (one JSON object per line via `-w "\n"`)
Expand All @@ -103,7 +103,7 @@ sys 0m4.138s
### 5. Start server with torchao variant of SAM2 and `--fast` and `--furious` optimizations
Start server
```
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast --furious
python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast --furious
```

Generate and save RLEs (one JSON object per line via `-w "\n"`)
Expand Down
48 changes: 48 additions & 0 deletions examples/sam2_amg_server/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import fire
import logging
import matplotlib.pyplot as plt
from server import file_bytes_to_image_tensor
from server import show_anns
from server import model_type_to_paths
from server import MODEL_TYPES_TO_MODEL
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from torchao._models.sam2.utils.amg import rle_to_mask
from io import BytesIO

def main_docstring():
    """Build the help text that the module later assigns to ``main.__doc__``.

    The list of accepted model types is rendered from the keys of
    ``MODEL_TYPES_TO_MODEL`` so the help text follows that table.
    """
    model_types = ", ".join(MODEL_TYPES_TO_MODEL)
    return f"""
    Args:
        checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints
        model_type (str): Choose from one of {model_types}
        input_path (str): Path to input image
        output_path (str): Path to output image
    """

def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False):
    # NOTE: no literal docstring on purpose. The module assigns main.__doc__
    # from main_docstring() right after this definition.
    #
    # Flow: resolve checkpoint/config for model_type, build the SAM2 automatic
    # mask generator, run it on the image at input_path and write the image
    # with the annotations drawn on top to output_path in output_format.
    device = "cuda"
    sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
    if verbose:
        print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
    sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
    mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")

    # Read the input through a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(input_path, 'rb') as input_file:
        image_tensor = file_bytes_to_image_tensor(bytearray(input_file.read()))
    if verbose:
        print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
    masks = mask_generator.generate(image_tensor)

    # Save an example
    # Figure size in inches times dpi=100 gives shape[1] x shape[0] pixels
    # (presumably width x height of the image; confirm the tensor layout).
    plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100)
    plt.imshow(image_tensor)
    show_anns(masks, rle_to_mask)
    plt.axis('off')
    plt.tight_layout()
    # Render into memory first: the output file is only opened once rendering
    # has succeeded, so a failing savefig cannot leave a truncated file behind.
    buf = BytesIO()
    plt.savefig(buf, format=output_format)
    # getvalue() returns the whole buffer regardless of the stream position.
    with open(output_path, "wb") as file:
        file.write(buf.getvalue())

# An f-string cannot serve as a literal docstring, so the generated help text
# is attached to main after the function has been defined.
main.__doc__ = main_docstring()
if __name__ == "__main__":
    # Hand main to fire when this file is executed as a script.
    fire.Fire(main)
1 change: 1 addition & 0 deletions examples/sam2_amg_server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ hydra-core
tqdm
iopath
python-multipart
requests
67 changes: 64 additions & 3 deletions examples/sam2_amg_server/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import requests
import uvicorn
import fire
import tempfile
Expand Down Expand Up @@ -37,6 +38,23 @@
# torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

def download_file(url, download_dir, timeout=60):
    """Download ``url`` into ``download_dir``, named after the URL's last path component.

    The payload is streamed into a temporary ``<name>.part`` file and only
    moved to its final name once the transfer has finished, so an interrupted
    download never leaves a truncated file under the final name.

    Args:
        url (str): Address of the file to fetch.
        download_dir (str or Path): Target folder; created (including parents)
            if it does not exist.
        timeout (float or tuple): Timeout in seconds forwarded to
            ``requests.get`` so a stalled server cannot block forever.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Create the directory if it doesn't exist
    download_dir = Path(download_dir)
    download_dir.mkdir(parents=True, exist_ok=True)
    # Extract the file name from the URL
    file_name = url.split('/')[-1]
    # Define the full path for the downloaded file
    file_path = download_dir / file_name
    # model_type_to_paths treats the mere existence of file_path as "checkpoint
    # available", so the bytes go to a sibling temp file until they are complete.
    partial_path = download_dir / f"{file_name}.part"
    try:
        # The context manager releases the streamed connection on every path.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an error for bad responses
            # Write the file to the specified directory
            print(f"Downloading '{file_name}' to '{download_dir}'")
            with open(partial_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        # Publish the finished download under its final name.
        partial_path.replace(file_path)
    finally:
        # Only still present when the download or the rename failed.
        if partial_path.exists():
            partial_path.unlink()
    print(f"Downloaded '{file_name}' to '{download_dir}'")

def example_shapes():
return [(848, 480, 3),
Expand Down Expand Up @@ -272,7 +290,51 @@ def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
print(f"mIoU is {miou} with equal count {equal_count} out of {len(masks)}")


# model_type -> config file name; model_type_to_paths prefixes it with
# "configs/sam2.1/" to form the config path.
MODEL_TYPES_TO_CONFIG = {
    "tiny": "sam2.1_hiera_t.yaml",
    "small": "sam2.1_hiera_s.yaml",
    "plus": "sam2.1_hiera_b+.yaml",
    "large": "sam2.1_hiera_l.yaml",
}

# model_type -> checkpoint file name, looked up inside the user-supplied
# checkpoint folder. Its keys are also what the help text and the
# ValueError message list as valid model types.
MODEL_TYPES_TO_MODEL = {
    "tiny": "sam2.1_hiera_tiny.pt",
    "small": "sam2.1_hiera_small.pt",
    "plus": "sam2.1_hiera_base_plus.pt",
    "large": "sam2.1_hiera_large.pt",
}


# model_type -> download URL used when the checkpoint file is missing locally.
# The last path component of each URL equals the matching file name in
# MODEL_TYPES_TO_MODEL, which is the name download_file saves it under.
MODEL_TYPES_TO_URL = {
    "tiny": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
    "small": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
    "plus": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
    "large": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
}


def main_docstring():
    """Build the help text that the module later assigns to ``main.__doc__``.

    The list of accepted model types is rendered from the keys of
    ``MODEL_TYPES_TO_MODEL`` so the help text follows that table.
    """
    model_types = ", ".join(MODEL_TYPES_TO_MODEL)
    return f"""
    Args:
        checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints
        model_type (str): Choose from one of {model_types}
    """


def model_type_to_paths(checkpoint_path, model_type):
    """Resolve ``model_type`` to a checkpoint path and a config path.

    If the checkpoint file is not present in ``checkpoint_path`` it is
    downloaded there first.

    Args:
        checkpoint_path (str or Path): Folder that holds (or will receive)
            the checkpoint file.
        model_type (str): One of the keys of ``MODEL_TYPES_TO_CONFIG``.

    Returns:
        tuple: ``(checkpoint file as Path, config path as str)``.

    Raises:
        ValueError: If ``model_type`` is not a known model type.
    """
    if model_type not in MODEL_TYPES_TO_CONFIG:
        valid_types = ', '.join(MODEL_TYPES_TO_MODEL)
        raise ValueError(f"Expected model_type to be one of {valid_types} but got {model_type}")

    checkpoint_file = Path(checkpoint_path) / MODEL_TYPES_TO_MODEL[model_type]
    if not checkpoint_file.exists():
        print(f"Can't find checkpoint {checkpoint_file} in folder {checkpoint_path}. Downloading.")
        download_file(MODEL_TYPES_TO_URL[model_type], checkpoint_path)
    # Internal sanity check: by now the file was either already there or has just been fetched.
    assert checkpoint_file.exists(), "Can't find downloaded file. Please open an issue."

    config_name = MODEL_TYPES_TO_CONFIG[model_type]
    return checkpoint_file, f"configs/sam2.1/{config_name}"


def main(checkpoint_path,
model_type,
baseline=False,
fast=False,
furious=False,
Expand Down Expand Up @@ -306,9 +368,7 @@ def main(checkpoint_path,
from torchao._models.sam2.utils.amg import rle_to_mask

device = "cuda"
from pathlib import Path
sam2_checkpoint = Path(checkpoint_path) / Path("sam2.1_hiera_large.pt")
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)

logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
Expand Down Expand Up @@ -450,5 +510,6 @@ async def upload_image(image: UploadFile = File(...)):
# uvicorn.run(app, host=host, port=port, log_level="info")
uvicorn.run(app, host=host, port=port)

# An f-string cannot serve as a literal docstring, so the generated help text
# is attached to main after the function has been defined.
main.__doc__ = main_docstring()
if __name__ == "__main__":
    # Hand main to fire when this file is executed as a script.
    fire.Fire(main)
28 changes: 14 additions & 14 deletions torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
_target_: torchao._models.sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
_target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
_target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera
embed_dim: 112
num_heads: 2
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
_target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
_target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
Expand All @@ -24,17 +24,17 @@ model:
fpn_interp_model: nearest

memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
_target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
_target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
_target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
embedding_dim: 256
Expand All @@ -45,7 +45,7 @@ model:
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
_target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
rope_k_repeat: True
Expand All @@ -57,23 +57,23 @@ model:
num_layers: 4

memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
_target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
_target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
_target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
_target_: torchao._models.sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
_target_: torchao._models.sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
Expand Down
28 changes: 14 additions & 14 deletions torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@

# Model
model:
_target_: sam2.modeling.sam2_base.SAM2Base
_target_: torchao._models.sam2.modeling.sam2_base.SAM2Base
image_encoder:
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
_target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder
scalp: 1
trunk:
_target_: sam2.modeling.backbones.hieradet.Hiera
_target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera
embed_dim: 96
num_heads: 1
stages: [1, 2, 11, 2]
global_att_blocks: [7, 10, 13]
window_pos_embed_bkg_spatial_size: [7, 7]
neck:
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
_target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
_target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 256
normalize: true
scale: null
Expand All @@ -27,17 +27,17 @@ model:
fpn_interp_model: nearest

memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
_target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
_target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
_target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
embedding_dim: 256
Expand All @@ -48,7 +48,7 @@ model:
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
_target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
rope_k_repeat: True
Expand All @@ -60,23 +60,23 @@ model:
num_layers: 4

memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
_target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
_target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
_target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
_target_: torchao._models.sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
_target_: torchao._models.sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
Expand Down
Loading

0 comments on commit b2e42ff

Please sign in to comment.