Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch committed Nov 27, 2024
1 parent ac87cab commit 5abcb9c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 22 deletions.
15 changes: 4 additions & 11 deletions examples/sam2_amg_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def main_docstring():
"""


def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=False):
def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
device = "cuda"
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
if verbose:
Expand All @@ -33,17 +33,10 @@ def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=102
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
if furious:
set_furious(mask_generator)
print("load_fast: ", load_fast)
if load_fast:
import time
t0 = time.time()
print(f"Start load. {t0}")
load_aot_fast(mask_generator)
print(f"End load. {time.time() - t0}")
load_aot_fast(mask_generator, load_fast)
if fast:
set_aot_fast(mask_generator)
import sys; sys.exit(1)
set_fast(mask_generator)
set_fast(mask_generator, load_fast)

image_tensor = file_bytes_to_image_tensor(input_bytes)
if verbose:
Expand All @@ -62,7 +55,7 @@ def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=102
buf.seek(0)
return buf.getvalue()

def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=False):
def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
input_bytes = bytearray(open(input_path, 'rb').read())
output_bytes = main_headless(checkpoint_path,
model_type,
Expand Down
20 changes: 9 additions & 11 deletions examples/sam2_amg_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,10 @@ def get_dense_pe(self, *args, **kwargs) -> torch.Tensor:
return self.other.get_dense_pe(*args, **kwargs)

def load_aot_fast(mask_generator, model_directory):
t0 = time.time()
path = Path(model_directory) / Path(f"sam2_image_encoder.pt2")
assert path.exists(), f"Expected {path} to exist."
print(f"Start load from {path}")
pkg = torch._inductor.aoti_load_package(str(path))
pkg_m = LoadedModel(pkg)
mask_generator.predictor.model.image_encoder = pkg_m
Expand All @@ -487,9 +489,11 @@ def load_aot_fast(mask_generator, model_directory):
# pkg_m = LoadedModel(pkg)
# mask_generator.predictor.model.sam_mask_decoder.transformer = pkg_m

print(f"End load. Took {time.time() - t0}s")

def set_fast(mask_generator, load_fast=False):
if not load_fast:

def set_fast(mask_generator, load_fast=""):
if load_fast == "":
# TODO: Using CUDA graphs can cause numerical differences?
mask_generator.predictor.model.image_encoder = torch.compile(
mask_generator.predictor.model.image_encoder,
Expand Down Expand Up @@ -569,19 +573,13 @@ def main(checkpoint_path,
logging.info(f"Using {points_per_batch} points_per_batch")
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")


if load_fast != "":
print(f"Loading compiled models from {load_fast}")
import time
t0 = time.time()
print(f"Start load. {t0}")
load_aot_fast(mask_generator, load_fast)
print(f"End load. {time.time() - t0}")

if save_fast != "":
assert load_fast == "", "Can't save compiled models while loading them with --load-fast."
assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible."
print(f"Saving compiled models to {save_fast}")
print(f"Saving compiled models under directory {save_fast}")
set_aot_fast(mask_generator, save_fast)

if fast:
Expand All @@ -594,9 +592,9 @@ def main(checkpoint_path,
elif use_autoquant:
from torchao import autoquant
from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
mask_generator.predictor.model = autoquant(mask_generator.predictor.model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)
mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40)

mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
# mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40)
# NOTE: Not baseline feature
mask_generator.predictor._transforms_device = mask_generator.predictor.device
torch.set_float32_matmul_precision('high')
Expand Down

0 comments on commit 5abcb9c

Please sign in to comment.