Skip to content

Commit

Permalink
SAM2 Fast AMG: memory profiling and more compile (#1296)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Nov 16, 2024
1 parent 6f17810 commit d4ca98f
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 42 deletions.
10 changes: 5 additions & 5 deletions examples/sam2_amg_server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ Experiments run on H100 and with batch size 1
| mode | mIoU | mask count mismatch | avg. ms per request | max. memory (MiB (%)) | batch size | points per batch |
| -------------- | ----------------- | ------------------- | ------------------- | --------------------- | ---------- | ---------------- |
| baseline | 1.0 | 0 | 863 | 4013MiB (4%) | 1 | 64 |
| ao | 0.9999980926513672 | 6 | 586 | | 1 | 64 |
| fast | 0.9937329888343811 | 191 | 333 | | 1 | 1024 |
| fast | 0.9937219619750977 | 192 | 324 | | 16 | 1024 |
| fast + furious | 0.9804400205612183 | 292 | 131 | | 1 | 1024 |
| fast + furious | 0.9806423187255859 | 282 | 130 | | 16 | 1024 |
| ao | 0.9999980926513672 | 6 | 586 | 3257MiB (3%) | 1 | 64 |
| fast | 0.993732988834381 | 191 | 326 | 27197MiB (27%) | 1 | 1024 |
| fast | 0.9937511086463928 | 194 | 315 | 27488MiB (28%) | 16 | 1024 |
| fast + furious | 0.9817246198654175 | 266 | 120 | 13616MiB (13%) | 1 | 1024 |
| fast + furious | 0.9794579744338989 | 274 | 122 | 13808MiB (14%) | 16 | 1024 |

mask count mismatch counts the number of requests where the number of masks differs from the baseline.
For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
Expand Down
42 changes: 34 additions & 8 deletions examples/sam2_amg_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,26 @@ def profiler_runner(path, fn, *args, **kwargs):
return result


def memory_runner(path, fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` while recording CUDA memory history.

    The allocator snapshot taken after the call is pickled to ``path``.

    Args:
        path: Destination file for the pickled snapshot.
        fn: Callable to execute while recording is active.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    import pickle

    print("Start memory recording")
    torch.cuda.synchronize()
    torch.cuda.memory._record_memory_history(
        True,
        trace_alloc_max_entries=100000,
        trace_alloc_record_context=True
    )
    try:
        result = fn(*args, **kwargs)
        torch.cuda.synchronize()
        snapshot = torch.cuda.memory._snapshot()
    finally:
        # Switch recording off again, even when fn raises. Otherwise the
        # allocator keeps tracing for the rest of the process lifetime.
        torch.cuda.memory._record_memory_history(enabled=None)
    print("Finish memory recording")
    with open(path, 'wb') as f:
        pickle.dump(snapshot, f)
    # Use to convert pickle file into html
    # python torch/cuda/_memory_viz.py trace_plot <snapshot>.pickle -o <snapshot>.html
    return result


def image_tensor_to_masks(example_image, mask_generator):
    """Return the result of ``mask_generator.generate`` for a single image."""
    return mask_generator.generate(example_image)
Expand Down Expand Up @@ -187,7 +207,7 @@ def process_batch(batch, mask_generator):
print(f"Processing batch of len {len(batch)} using generate_batch")
masks = mask_generator.generate_batch(image_tensors)
print(f"Took avg. {(time.time() - t) / len(batch)}s per batch entry")
# max_memory_allocated()
max_memory_allocated()
return masks


Expand Down Expand Up @@ -259,6 +279,7 @@ def main(checkpoint_path,
unittest=False,
benchmark=False,
profile=None,
memory_profile=None,
verbose=False,
points_per_batch=64,
port=5000,
Expand Down Expand Up @@ -305,13 +326,6 @@ def main(checkpoint_path,
dynamic=False,
)

mask_generator.predictor.model.sam_prompt_encoder.forward = torch.compile(
mask_generator.predictor.model.sam_prompt_encoder.forward,
mode="max-autotune",
fullgraph=True,
dynamic=False,
)

mask_generator.predictor._predict_masks = torch.compile(
mask_generator.predictor._predict_masks,
mode="max-autotune",
Expand All @@ -329,6 +343,7 @@ def main(checkpoint_path,
mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16)
# NOTE: Not baseline feature
mask_generator.predictor._image_dtype = torch.float16
mask_generator.predictor._transforms_device = mask_generator.predictor.device
torch.set_float32_matmul_precision('high')
mask_generator.predictor.model.sam_mask_decoder = mask_generator.predictor.model.sam_mask_decoder.to(torch.float16)
# NOTE: Not baseline feature
Expand Down Expand Up @@ -363,11 +378,15 @@ def main(checkpoint_path,
for i, shapes in enumerate([example_shapes(), example_shapes_2()]):
    print(f"batch size {batch_size} example shapes {i} benchmark")
    random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in shapes]
    # Repeat the example images until there are at least batch_size of them.
    # The repeat count is a ceil-division by the number of images; dividing
    # (len + batch_size) by batch_size always gives 1 in this branch, which
    # leaves the list unchanged and the batch shorter than requested.
    if 0 < len(random_images) < batch_size:
        num_repeat = (batch_size + len(random_images) - 1) // len(random_images)
        random_images = num_repeat * random_images

    if batch_size == 1:
        # Plain loop: the calls are made for their side effects only.
        for r in random_images:
            benchmark_fn(image_tensor_to_masks, r, mask_generator)
    else:
        random_images = random_images[:batch_size]
        print("len(random_images): ", len(random_images))
        benchmark_fn(image_tensors_to_masks, random_images, mask_generator)

if profile is not None:
Expand All @@ -377,6 +396,13 @@ def main(checkpoint_path,
else:
profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

if memory_profile is not None:
print(f"Saving memory profile under {memory_profile}")
if batch_size == 1:
memory_runner(memory_profile, image_tensor_to_masks, image_tensor, mask_generator)
else:
memory_runner(memory_profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

if dry:
return

Expand Down
4 changes: 0 additions & 4 deletions torchao/_models/sam2/automatic_mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,6 @@ def _process_crop_batch(
i = 0
batch_features = self.predictor._features
all_crop_data = []
all_all_batch_iterator_data = []
for (cropped_im, crop_box, layer_idx, orig_size) in zip(all_cropped_im, all_crop_box, all_layer_idx, all_orig_size):
cropped_im_size = cropped_im.shape[:2]
self.predictor.reset_predictor()
Expand Down Expand Up @@ -425,9 +424,6 @@ def _process_crop_batch(
data = self._process_batch_fullgraph(points, im_size, crop_box, crop_box_torch, orig_size, normalize, orig_box_torch)
all_batch_iterator_data.append(data)
self.predictor.reset_predictor()
all_all_batch_iterator_data.append(all_batch_iterator_data)

for all_batch_iterator_data in all_all_batch_iterator_data:

result_data = None
with torch.autograd.profiler.record_function("all mask_to_rle_pytorch_2"):
Expand Down
4 changes: 4 additions & 0 deletions torchao/_models/sam2/modeling/sam/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ def predict_masks(
# TODO: Not specifying scale kwarg in SDPA will cause NaN here
# print("hs.isnan().any(): ", hs.isnan().any().item())

# TODO: These outputs are being immediately indexed.
# Is there something to remove?
# TODO: The fact that there's a crop box and we try to find stuff at the
# boundary later and there's generally cropping going on smells of padding.
iou_token_out = hs[:, s, :]
mask_tokens_out = hs[:, s + 1: (s + 1 + self.num_mask_tokens), :]

Expand Down
15 changes: 9 additions & 6 deletions torchao/_models/sam2/sam2_image_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
]

self._image_dtype = torch.float32
self._transforms_device = "cpu"

@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
Expand Down Expand Up @@ -110,10 +111,10 @@ def set_image(
raise NotImplementedError("Image format not supported")

input_image = self._transforms.to_tensor(image)
# TODO: Doing these transforms on the GPU changes the numerics
# NOTE: Doing these transforms on the GPU changes the numerics
input_image = input_image.to(device=self._transforms_device)
input_image = self._transforms.transforms(input_image)
input_image = input_image.to(device=self.device)
# TODO: Doing this here instead causes masks to not match reference exactly
# input_image = self._transforms.transforms(input_image)
input_image = input_image[None, ...].to(dtype=self._image_dtype)

Expand Down Expand Up @@ -167,8 +168,10 @@ def set_image_batch(
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
logging.info("Computing image embeddings for the provided images...")
backbone_out = self.model.forward_image(img_batch)
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
with torch.autograd.profiler.record_function("forward_image"):
backbone_out = self.model.forward_image(img_batch)
with torch.autograd.profiler.record_function("_prepare_backbone_features"):
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
# Add no_mem_embed, which is added to the lowest-res feature map during training on videos
if self.model.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
Expand Down Expand Up @@ -462,11 +465,11 @@ def _predict_masks_postprocess(self, low_res_masks, img_idx, return_logits, chan
# Upscale the masks to the original image resolution
if channel_1:
masks = self._transforms.postprocess_masks_1_channel(
low_res_masks, self._orig_hw[img_idx]
low_res_masks, self._orig_hw[img_idx], self._image_dtype
)
else:
masks = self._transforms.postprocess_masks(
low_res_masks, self._orig_hw[img_idx]
low_res_masks, self._orig_hw[img_idx], self._image_dtype
)
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
if not return_logits:
Expand Down
56 changes: 39 additions & 17 deletions torchao/_models/sam2/utils/amg.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
return mask.transpose() # Put in C order


def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor):
@torch.compile(fullgraph=True, dynamic=True)
def _mask_to_rle_pytorch_2_0_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor):
"""
Encodes masks to an uncompressed RLE, in the format expected by
pycoco tools.
Expand All @@ -227,33 +228,53 @@ def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tenso
with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: change indices"):
# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]
a = torch.tensor([[True]])
if diff.is_cuda:
a = a.pin_memory().cuda()
# a = a.to(diff.device)
# a = torch.tensor([[True]])
a = torch.ones((1, 1), dtype=bool, device=diff.device)
# if diff.is_cuda:
# a = a.pin_memory().cuda()
# # a = a.to(diff.device)
a = a.expand_as(diff.narrow(1, 0, 1))
diff = torch.cat([a, diff, a], dim=1)
if diff.numel() > 2147483646:
num_chunks = (diff.numel() + 2147483646) // 2147483646
change_indices = torch.cat([d.nonzero() for d in diff.chunk(num_chunks)])
else:
change_indices = diff.nonzero()
return diff


@torch.compile(fullgraph=True, dynamic=True)
def _mask_to_rle_pytorch_2_0_1(tensor: torch.Tensor, diff: torch.Tensor, change_indices: torch.Tensor) -> (torch.Tensor, torch.Tensor):
tensor = tensor.permute(0, 2, 1).flatten(1)

with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: all_btw_idx"):
alt_lens = diff.sum(dim=1)

all_cur_idx = change_indices[:, 1]
all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx
all_cur_idx_0 = all_cur_idx.narrow(0, 1, all_cur_idx.size(0) - 1)
all_cur_idx_1 = all_cur_idx.narrow(0, 0, 1)
all_btw_idx = torch.cat([all_cur_idx_0, all_cur_idx_1])
all_btw_idx = all_btw_idx - all_cur_idx

with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: Encode run length"):
alt_lens_nt = torch.nested.nested_tensor_from_jagged(all_btw_idx, lengths=alt_lens)
# Encode run length
counts_init = (tensor[:, 0] == 0)
return RLEData(alt_lens_nt=alt_lens_nt,
counts_init=counts_init,
b=b,
h=h,
w=w)
return alt_lens_nt, counts_init


def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> RLEData:
    """Run the three RLE pre-processing stages on a 3-D mask tensor.

    The stages are: the compiled diff computation, an eager ``nonzero``
    on that diff, and the compiled run-length step. Each stage is wrapped
    in its own profiler range.

    Args:
        tensor: Mask tensor whose shape unpacks as ``(b, h, w)``.

    Returns:
        ``RLEData`` carrying ``alt_lens_nt``, ``counts_init`` and the
        input's ``b``, ``h`` and ``w``.
    """
    record = torch.autograd.profiler.record_function
    # Element count above which nonzero() is applied chunk by chunk.
    # NOTE(review): presumably guards an int32 limit in nonzero - confirm.
    max_numel = 2147483646
    b, h, w = tensor.shape

    with record("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_0"):
        diff = _mask_to_rle_pytorch_2_0_0(tensor)

    with record("mask_to_rle_pytorch_2: nonzero"):
        if diff.numel() <= max_numel:
            change_indices = diff.nonzero()
        else:
            num_chunks = (diff.numel() + max_numel) // max_numel
            pieces = [piece.nonzero() for piece in diff.chunk(num_chunks)]
            change_indices = torch.cat(pieces)

    with record("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_1"):
        alt_lens_nt, counts_init = _mask_to_rle_pytorch_2_0_1(tensor, diff, change_indices)

    return RLEData(
        alt_lens_nt=alt_lens_nt,
        counts_init=counts_init,
        b=b,
        h=h,
        w=w,
    )


def _mask_to_rle_pytorch_2_1(rle_data: RLEData):
Expand All @@ -276,7 +297,8 @@ def _mask_to_rle_pytorch_2_1(rle_data: RLEData):


def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]:
return _mask_to_rle_pytorch_2_1(_mask_to_rle_pytorch_2_0(tensor))
with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2"):
return _mask_to_rle_pytorch_2_1(_mask_to_rle_pytorch_2_0(tensor))


def area_from_rle(rle: Dict[str, Any]) -> int:
Expand Down
6 changes: 4 additions & 2 deletions torchao/_models/sam2/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def transform_boxes(
boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
return boxes

def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
def postprocess_masks(self, masks: torch.Tensor, orig_hw, output_dtype) -> torch.Tensor:
"""
Perform PostProcessing on output masks.
"""
Expand Down Expand Up @@ -114,10 +114,11 @@ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
)
masks = input_masks

masks = masks.to(output_dtype)
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
return masks

def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw, output_dtype) -> torch.Tensor:
"""
Perform PostProcessing on output masks.
"""
Expand Down Expand Up @@ -161,5 +162,6 @@ def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw) -> torch.Ten
)
masks = input_masks

masks = masks.to(output_dtype)
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
return masks

0 comments on commit d4ca98f

Please sign in to comment.