From 54c3f31a4675ef57f2ee467a6ce65851526102f5 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Tue, 26 Nov 2024 01:26:50 -0800 Subject: [PATCH] Reduce SAM2 AMG cli startup by using deploy (#1350) --- examples/sam2_amg_server/cli.py | 24 +++- examples/sam2_amg_server/cli_on_modal.py | 175 ++++++++++++++++------- 2 files changed, 145 insertions(+), 54 deletions(-) diff --git a/examples/sam2_amg_server/cli.py b/examples/sam2_amg_server/cli.py index b8afcfc3c..2fead4b5a 100644 --- a/examples/sam2_amg_server/cli.py +++ b/examples/sam2_amg_server/cli.py @@ -21,7 +21,8 @@ def main_docstring(): output_path (str): Path to output image """ -def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False): + +def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False): device = "cuda" sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) if verbose: @@ -32,12 +33,13 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch= set_fast(mask_generator) if furious: set_furious(mask_generator) - image_tensor = file_bytes_to_image_tensor(bytearray(open(input_path, 'rb').read())) + image_tensor = file_bytes_to_image_tensor(input_bytes) if verbose: print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.") masks = mask_generator.generate(image_tensor) - - # Save an example + + if verbose: + print("Generating mask annotations for input image.") plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100) plt.imshow(image_tensor) show_anns(masks, rle_to_mask) @@ -46,8 +48,20 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch= buf = BytesIO() plt.savefig(buf, format=output_format) buf.seek(0) + return buf.getvalue() + +def main(checkpoint_path, model_type, input_path, output_path, 
points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False): + input_bytes = bytearray(open(input_path, 'rb').read()) + output_bytes = main_headless(checkpoint_path, + model_type, + input_bytes, + points_per_batch=points_per_batch, + output_format=output_format, + verbose=verbose, + fast=fast, + furious=furious) with open(output_path, "wb") as file: - file.write(buf.getvalue()) + file.write(output_bytes) main.__doc__ = main_docstring() if __name__ == "__main__": diff --git a/examples/sam2_amg_server/cli_on_modal.py b/examples/sam2_amg_server/cli_on_modal.py index fdd6316b2..3295ede84 100644 --- a/examples/sam2_amg_server/cli_on_modal.py +++ b/examples/sam2_amg_server/cli_on_modal.py @@ -1,10 +1,11 @@ from pathlib import Path +import json +import fire import modal -app = modal.App("torchao-sam-2-cli") - TARGET = "/root/" +DOWNLOAD_URL_BASE = "https://raw.githubusercontent.com/pytorch/ao/refs/heads" image = ( modal.Image.debian_slim(python_version="3.12.7") @@ -34,61 +35,137 @@ ) ) -checkpoints = modal.Volume.from_name("checkpoints", create_if_missing=True) +app = modal.App("torchao-sam-2-cli", image=image) + +checkpoints = modal.Volume.from_name("torchao-sam-2-cli-checkpoints", create_if_missing=True) +data = modal.Volume.from_name("torchao-sam-2-cli-data", create_if_missing=True) + -@app.function( - image=image, +@app.cls( gpu="H100", + container_idle_timeout=20 * 60, + timeout=20 * 60, volumes={ TARGET + "checkpoints": checkpoints, - # # mount the caches of torch.compile and friends - # "/root/.nv": modal.Volume.from_name("torchao-sam-2-cli-nv-cache", create_if_missing=True), - # "/root/.triton": modal.Volume.from_name( - # "torchao-sam-2-cli-triton-cache", create_if_missing=True - # ), - # "/root/.inductor-cache": modal.Volume.from_name( - # "torchao-sam-2-cli-inductor-cache", create_if_missing=True - # ), + TARGET + "data": data, }, - timeout=60 * 60, ) -def eval(input_bytes, fast, furious): - import torch - import torchao - import os 
+class Model:
+    model_type: str = modal.parameter(default="large")
+    points_per_batch: int = modal.parameter(default=1024)
+    fast: int = modal.parameter(default=0)
+    furious: int = modal.parameter(default=0)
 
-    import subprocess
-    from pathlib import Path
-    from git import Repo
+    def calculate_file_hash(self, file_path, hash_algorithm='sha256'):
+        import hashlib
+        """Calculate the hash of a file."""
+        hash_func = hashlib.new(hash_algorithm)
+        with open(file_path, 'rb') as f:
+            for chunk in iter(lambda: f.read(4096), b""):
+                hash_func.update(chunk)
+        return hash_func.hexdigest()
 
-    def download_file(url, filename):
+    def download_file(self, url, filename):
+        import subprocess
         command = f"wget -O {filename} {url}"
         subprocess.run(command, shell=True, check=True)
 
-    os.chdir(Path(TARGET))
-    download_file("https://raw.githubusercontent.com/pytorch/ao/refs/heads/climodal1/examples/sam2_amg_server/cli.py", "cli.py")
-    download_file("https://raw.githubusercontent.com/pytorch/ao/refs/heads/climodal1/examples/sam2_amg_server/server.py", "server.py")
-    # Create a Path object for the current directory
-    current_directory = Path('.')
-
-    with open('/tmp/dog.jpg', 'wb') as file:
-        file.write(input_bytes)
-
-    import sys
-    sys.path.append(".")
-    from cli import main as cli_main
-    cli_main(Path(TARGET) / Path("checkpoints"),
-             model_type="large",
-             input_path="/tmp/dog.jpg",
-             output_path="/tmp/dog_masked_2.png",
-             verbose=True,
-             fast=fast,
-             furious=furious)
-
-    return bytearray(open('/tmp/dog_masked_2.png', 'rb').read())
-
-@app.local_entrypoint()
-def main(input_path, output_path, fast=False, furious=False):
-    bytes = eval.remote(open(input_path, 'rb').read(), fast, furious)
-    with open(output_path, "wb") as file:
-        file.write(bytes)
+    @modal.build()
+    @modal.enter()
+    def build(self):
+        import os
+        from torchao._models.sam2.build_sam import build_sam2
+        from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+
+        download_url_branch = "climodal2"
+        
download_url = f"{DOWNLOAD_URL_BASE}/{download_url_branch}/" + download_url += "examples/sam2_amg_server/" + + h = self.calculate_file_hash(TARGET + "data/cli.py") + print("cli.py hash: ", h) + if h != "b38d60cb6fad555ad3c33081672ae981a5e4e744199355dfd24d395d20dfefda": + self.download_file(download_url + "cli.py", TARGET + "data/cli.py") + + h = self.calculate_file_hash(TARGET + "data/server.py") + print("server.py hash: ", h) + if h != "af33fdb9bcfe668b7764cb9c86f5fa9a799c999306e7c7e5b28c988b2616a0ae": + self.download_file(download_url + "server.py", TARGET + "data/server.py") + + os.chdir(Path(TARGET + "data")) + import sys + sys.path.append(".") + + from server import model_type_to_paths + from server import set_fast + from server import set_furious + + + device = "cuda" + checkpoint_path = Path(TARGET) / Path("checkpoints") + sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, self.model_type) + sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) + mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=self.points_per_batch, output_mode="uncompressed_rle") + self.mask_generator = mask_generator + if self.fast: + set_fast(mask_generator) + if self.furious: + set_furious(mask_generator) + + @modal.method() + def inference_rle(self, input_bytes) -> dict: + import os + os.chdir(Path(TARGET + "data")) + import sys + sys.path.append(".") + from server import file_bytes_to_image_tensor + from server import masks_to_rle_dict + image_tensor = file_bytes_to_image_tensor(input_bytes) + masks = self.mask_generator.generate(image_tensor) + return masks_to_rle_dict(masks) + + @modal.method() + def inference(self, input_bytes, output_format='png'): + import os + os.chdir(Path(TARGET + "data")) + import sys + sys.path.append(".") + from server import file_bytes_to_image_tensor + from server import show_anns + image_tensor = file_bytes_to_image_tensor(input_bytes) + masks = self.mask_generator.generate(image_tensor) 
+ + import matplotlib.pyplot as plt + from io import BytesIO + from torchao._models.sam2.utils.amg import rle_to_mask + plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100) + plt.imshow(image_tensor) + show_anns(masks, rle_to_mask) + plt.axis('off') + plt.tight_layout() + buf = BytesIO() + plt.savefig(buf, format=output_format) + buf.seek(0) + return buf.getvalue() + + +def main(input_path, output_path, fast=False, furious=False, model_type="large", output_rle=False): + input_bytes = bytearray(open(input_path, 'rb').read()) + try: + model = modal.Cls.lookup("torchao-sam-2-cli", "Model")() + except modal.exception.NotFoundError: + print("Can't find running app. To deploy the app run the following command. Note that this costs money! See https://modal.com/pricing") + print("modal deploy cli_on_modal.py") + return + + if output_rle: + output_dict = model.inference_rle.remote(input_bytes) + with open(output_path, "w") as file: + file.write(json.dumps(output_dict, indent=4)) + else: + output_bytes = model.inference.remote(input_bytes) + with open(output_path, "wb") as file: + file.write(output_bytes) + + +if __name__ == "__main__": + fire.Fire(main)