diff --git a/examples/sam2_amg_server/README.md b/examples/sam2_amg_server/README.md index 730be293cd..43fc2b2528 100644 --- a/examples/sam2_amg_server/README.md +++ b/examples/sam2_amg_server/README.md @@ -8,7 +8,7 @@ curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output Start the server ``` -python server.py ~/checkpoints/sam2 --port --host --fast +python server.py ~/checkpoints/sam2 large --port --host --fast ``` Collect the rles @@ -58,7 +58,7 @@ Make sure you've installed https://github.com/facebookresearch/sam2 Start server ``` -python server.py ~/checkpoints/sam2 --port --host --baseline +python server.py ~/checkpoints/sam2 large --port --host --baseline ``` Generate and save rles (one line per json via `-w "\n"`) @@ -73,7 +73,7 @@ sys 0m4.137s ### 3. Start server with torchao variant of SAM2 Start server ``` -python server.py ~/checkpoints/sam2 --port --host +python server.py ~/checkpoints/sam2 large --port --host ``` Generate and save rles (one line per json via `-w "\n"`) @@ -88,7 +88,7 @@ sys 0m4.350s ### 4. Start server with torchao variant of SAM2 and `--fast` optimizations Start server ``` -python server.py ~/checkpoints/sam2 --port --host --fast +python server.py ~/checkpoints/sam2 large --port --host --fast ``` Generate and save rles (one line per json via `-w "\n"`) @@ -103,7 +103,7 @@ sys 0m4.138s ### 5. Start server with torchao variant of SAM2 and `--fast` and `--furious` optimizations Start server ``` -python server.py ~/checkpoints/sam2 --port --host --fast --furious +python server.py ~/checkpoints/sam2 large --port --host --fast --furious ``` Generate and save rles (one line per json via `-w "\n"`) diff --git a/examples/sam2_amg_server/cli.py b/examples/sam2_amg_server/cli.py new file mode 100644 index 0000000000..265f8c7b73 --- /dev/null +++ b/examples/sam2_amg_server/cli.py @@ -0,0 +1,48 @@ +import fire +import logging +import matplotlib.pyplot as plt +from server import file_bytes_to_image_tensor +from server import show_anns +from server import model_type_to_paths +from server import MODEL_TYPES_TO_MODEL +from torchao._models.sam2.build_sam import build_sam2 +from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from torchao._models.sam2.utils.amg import rle_to_mask +from io import BytesIO + +def main_docstring(): + return f""" + Args: + checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints + model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())} + input_path (str): Path to input image + output_path (str): Path to output image + """ + +def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False): + device = "cuda" + sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) + if verbose: + print(f"Loading model {sam2_checkpoint} with config {model_cfg}") + sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) + mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") + image_tensor = file_bytes_to_image_tensor(bytearray(open(input_path, 'rb').read())) + if verbose: + print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.") + masks = mask_generator.generate(image_tensor) + + # Save an example + plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100) + 
plt.imshow(image_tensor) + show_anns(masks, rle_to_mask) + plt.axis('off') + plt.tight_layout() + buf = BytesIO() + plt.savefig(buf, format=output_format) + buf.seek(0) + with open(output_path, "wb") as file: + file.write(buf.getvalue()) + +main.__doc__ = main_docstring() +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/sam2_amg_server/requirements.txt b/examples/sam2_amg_server/requirements.txt index a77773d62b..e591e89100 100644 --- a/examples/sam2_amg_server/requirements.txt +++ b/examples/sam2_amg_server/requirements.txt @@ -7,3 +7,4 @@ hydra-core tqdm iopath python-multipart +requests diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index 1eafde8ee5..cbf916c2aa 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -1,4 +1,5 @@ import itertools +import requests import uvicorn import fire import tempfile @@ -37,6 +38,23 @@ # torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True +def download_file(url, download_dir): + # Create the directory if it doesn't exist + download_dir = Path(download_dir) + download_dir.mkdir(parents=True, exist_ok=True) + # Extract the file name from the URL + file_name = url.split('/')[-1] + # Define the full path for the downloaded file + file_path = download_dir / file_name + # Download the file + response = requests.get(url, stream=True) + response.raise_for_status() # Raise an error for bad responses + # Write the file to the specified directory + print(f"Downloading '{file_name}' to '{download_dir}'") + with open(file_path, 'wb') as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + print(f"Downloaded '{file_name}' to '{download_dir}'") def example_shapes(): return [(848, 480, 3), @@ -272,7 +290,51 @@ def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False): print(f"mIoU is {miou} with equal count {equal_count} out of {len(masks)}") +MODEL_TYPES_TO_CONFIG = { + "tiny": "sam2.1_hiera_t.yaml", + "small": "sam2.1_hiera_s.yaml", + "plus": "sam2.1_hiera_b+.yaml", + "large": "sam2.1_hiera_l.yaml", + } + +MODEL_TYPES_TO_MODEL = { + "tiny": "sam2.1_hiera_tiny.pt", + "small": "sam2.1_hiera_small.pt", + "plus": "sam2.1_hiera_base_plus.pt", + "large": "sam2.1_hiera_large.pt", + } + + +MODEL_TYPES_TO_URL = { + "tiny": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + "small": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + "plus": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + "large": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", + } + + +def main_docstring(): + return f""" + Args: + checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints + model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())} + """ + + +def model_type_to_paths(checkpoint_path, model_type): + if model_type not in MODEL_TYPES_TO_CONFIG.keys(): + raise ValueError(f"Expected model_type to be one of {', '.join(MODEL_TYPES_TO_MODEL.keys())} but got {model_type}") + sam2_checkpoint = Path(checkpoint_path) / Path(MODEL_TYPES_TO_MODEL[model_type]) + if not sam2_checkpoint.exists(): + print(f"Can't find checkpoint {sam2_checkpoint} in folder {checkpoint_path}. 
Downloading.") + download_file(MODEL_TYPES_TO_URL[model_type], checkpoint_path) + assert sam2_checkpoint.exists(), "Can't find downloaded file. Please open an issue." + model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}" + return sam2_checkpoint, model_cfg + + def main(checkpoint_path, + model_type, baseline=False, fast=False, furious=False, @@ -306,9 +368,7 @@ def main(checkpoint_path, from torchao._models.sam2.utils.amg import rle_to_mask device = "cuda" - from pathlib import Path - sam2_checkpoint = Path(checkpoint_path) / Path("sam2.1_hiera_large.pt") - model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" + sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}") sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) @@ -450,5 +510,6 @@ async def upload_image(image: UploadFile = File(...)): # uvicorn.run(app, host=host, port=port, log_level="info") uvicorn.run(app, host=host, port=port) +main.__doc__ = main_docstring() if __name__ == "__main__": fire.Fire(main) diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml index cbee3cf9b3..42cd897c67 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +++ b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml @@ -2,18 +2,18 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -24,17 +24,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -45,7 +45,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -57,23 +57,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine 
num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml index 8e803dfea5..898898b158 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +++ b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 
kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml index 983c2ea031..c6318f843b 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +++ b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/utils/amg.py b/torchao/_models/sam2/utils/amg.py index 5335ff0609..cf52cae327 100644 --- a/torchao/_models/sam2/utils/amg.py +++ b/torchao/_models/sam2/utils/amg.py @@ -225,16 +225,15 @@ def _mask_to_rle_pytorch_2_0_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Ten b, h, w = tensor.shape tensor = tensor.permute(0, 2, 1).flatten(1) - with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: change indices"): - # Compute change indices 
- diff = tensor[:, 1:] ^ tensor[:, :-1] - # a = torch.tensor([[True]]) - a = torch.ones((1, 1), dtype=bool, device=diff.device) - # if diff.is_cuda: - # a = a.pin_memory().cuda() - # # a = a.to(diff.device) - a = a.expand_as(diff.narrow(1, 0, 1)) - diff = torch.cat([a, diff, a], dim=1) + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + # a = torch.tensor([[True]]) + a = torch.ones((1, 1), dtype=bool, device=diff.device) + # if diff.is_cuda: + # a = a.pin_memory().cuda() + # # a = a.to(diff.device) + a = a.expand_as(diff.narrow(1, 0, 1)) + diff = torch.cat([a, diff, a], dim=1) return diff @@ -242,19 +241,21 @@ def _mask_to_rle_pytorch_2_0_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Ten def _mask_to_rle_pytorch_2_0_1(tensor: torch.Tensor, diff: torch.Tensor, change_indices: torch.Tensor) -> (torch.Tensor, torch.Tensor): tensor = tensor.permute(0, 2, 1).flatten(1) - with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: all_btw_idx"): - alt_lens = diff.sum(dim=1) + alt_lens = diff.sum(dim=1) - all_cur_idx = change_indices[:, 1] + all_cur_idx = change_indices[:, 1] + if all_cur_idx.numel() == 0: + all_cur_idx_0 = all_cur_idx + all_cur_idx_1 = all_cur_idx + else: all_cur_idx_0 = all_cur_idx.narrow(0, 1, all_cur_idx.size(0) - 1) all_cur_idx_1 = all_cur_idx.narrow(0, 0, 1) - all_btw_idx = torch.cat([all_cur_idx_0, all_cur_idx_1]) - all_btw_idx = all_btw_idx - all_cur_idx + all_btw_idx = torch.cat([all_cur_idx_0, all_cur_idx_1]) + all_btw_idx = all_btw_idx - all_cur_idx - with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: Encode run length"): - alt_lens_nt = torch.nested.nested_tensor_from_jagged(all_btw_idx, lengths=alt_lens) - # Encode run length - counts_init = (tensor[:, 0] == 0) + alt_lens_nt = torch.nested.nested_tensor_from_jagged(all_btw_idx, lengths=alt_lens) + # Encode run length + counts_init = (tensor[:, 0] == 0) return alt_lens_nt, counts_init
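
Usage note (supplementary to the patch above, not part of it): the new `cli.py` exposes `main(checkpoint_path, model_type, input_path, output_path, ...)` through `fire.Fire`, and `model_type_to_paths` downloads the matching `sam2.1_hiera_*.pt` checkpoint into the checkpoint folder if it is missing. Below is a minimal sketch of a programmatic call; all paths are placeholders, and it assumes a CUDA device and the torchao SAM2 dependencies from the README are installed. The shell equivalent would be `python cli.py /path/to/checkpoints/sam2 large /path/to/dog.jpg /path/to/dog_masks.png --verbose`.

```
# Sketch only: placeholder paths; run from examples/sam2_amg_server so `cli` is importable.
from cli import main as generate_masks

generate_masks(
    checkpoint_path="/path/to/checkpoints/sam2",  # folder with sam2.1_hiera_*.pt (downloaded on demand)
    model_type="large",                           # one of: tiny, small, plus, large
    input_path="/path/to/dog.jpg",                # input image
    output_path="/path/to/dog_masks.png",         # annotated image written here
    points_per_batch=1024,                        # default, shown for clarity
    verbose=True,
)
```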
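Since `requests` is now in `requirements.txt`, the README's `curl` upload can also be reproduced in Python. This is a sketch against the existing `/upload` endpoint (field name `image`, host and port taken from the README example); that the response body is the annotated image, mirroring what `curl ... --output` saves, is an assumption.

```
# Sketch only: assumes the server was started as shown in the README (port 5000).
import requests

with open("/path/to/dog.jpg", "rb") as f:
    resp = requests.post("http://127.0.0.1:5000/upload", files={"image": f})
resp.raise_for_status()

# Save the returned annotated image, mirroring `curl ... --output`.
with open("/path/to/annotated.png", "wb") as out:
    out.write(resp.content)
```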