DRAFT: SAM2 bfloat16 (without autocast) and more compile #1757

Draft · wants to merge 5 commits into base: main
examples/sam2_vos_example/compile_export_utils.py (4 changes: 2 additions & 2 deletions)

@@ -229,8 +229,8 @@ def load_exported_model(

 def set_fast(predictor, loaded_exported_model=False):
     if not loaded_exported_model:
-        predictor.image_encoder.trunk.forward = torch.compile(
-            predictor.image_encoder.trunk.forward,
+        predictor.image_encoder.forward = torch.compile(
+            predictor.image_encoder.forward,
             mode="max-autotune",
             fullgraph=True,
             dynamic=False,
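
Note: the compile target moves from the trunk to the whole image encoder, with the same compile settings. A minimal standalone sketch of that wrapping pattern (the helper name and the toy encoder below are illustrative, not part of this PR):

import torch
import torch.nn as nn

def compile_module_forward(module: nn.Module) -> None:
    # Replace the module's forward with a compiled version, mirroring the
    # settings in set_fast: max-autotune, a single full graph, static shapes.
    module.forward = torch.compile(
        module.forward,
        mode="max-autotune",
        fullgraph=True,
        dynamic=False,
    )

encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
compile_module_forward(encoder)
out = encoder(torch.randn(1, 3, 32, 32))  # first call triggers compilation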
examples/sam2_vos_example/video_profile.py (6 changes: 4 additions & 2 deletions)

@@ -270,7 +270,7 @@ def main(
     )

     # use bfloat16 for the entire notebook
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    # torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True

@@ -291,7 +291,9 @@ def main(
         # hydra_overrides_extra=hydra_overrides_extra,
     )
     predictor._frame_batch_size = frame_batch_size
-    predictor.image_encoder.trunk = predictor.image_encoder.trunk.to(torch.bfloat16)
+    # predictor.image_encoder = predictor.image_encoder.to(torch.bfloat16)
+    predictor = predictor.to(torch.bfloat16)
+    predictor.sam_mask_decoder._src_dtype = torch.bfloat16
     from torchao._models.sam2.modeling.sam.transformer import RoPEAttention

     rope_attention_modules = [
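
Note: the change above drops the global autocast context and instead converts the predictor's parameters to bfloat16, casting inputs explicitly at the call sites. A small sketch of the difference, assuming a generic nn.Module rather than the SAM2 predictor:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 16).to(device)
x = torch.randn(2, 16, device=device)

# Autocast approach: weights stay fp32, selected ops run in bfloat16.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    y_autocast = model(x)

# Explicit approach (as in this PR): weights are stored in bfloat16 and
# inputs are cast by hand, so no autocast context is needed.
model_bf16 = model.to(torch.bfloat16)
y_explicit = model_bf16(x.to(torch.bfloat16))
print(y_autocast.dtype, y_explicit.dtype)  # both torch.bfloat16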
torchao/_models/sam2/modeling/position_encoding.py (2 changes: 1 addition & 1 deletion)

@@ -129,7 +129,7 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
         """Positionally encode points that are normalized to [0,1]."""
         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
         coords = 2 * coords - 1
-        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = coords @ (self.positional_encoding_gaussian_matrix.float())
         coords = 2 * np.pi * coords
         # outputs d_1 x ... x d_n x C shape
         return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
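
Note: the .float() cast guards against a dtype mismatch. Once the predictor is moved to bfloat16, the registered Gaussian matrix is bfloat16 while the incoming point coordinates may still be fp32, and torch.matmul requires both operands to share a dtype. A tiny reproduction with standalone tensors (not the SAM2 buffer):

import torch

coords = torch.rand(5, 2)                            # fp32 coordinates
gaussian = torch.randn(2, 64, dtype=torch.bfloat16)  # buffer after .to(torch.bfloat16)

# coords @ gaussian would raise a dtype-mismatch RuntimeError;
# casting one operand, as the diff does, keeps the matmul in fp32.
encoded = coords @ gaussian.float()
print(encoded.dtype)  # torch.float32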
torchao/_models/sam2/modeling/sam2_base.py (67 changes: 37 additions & 30 deletions)

@@ -352,8 +352,8 @@ def _forward_sam_heads(
             object_score_logits,
         ) = self.sam_mask_decoder(
             image_embeddings=backbone_features,
-            image_pe=self.sam_prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embeddings,
+            image_pe=self.sam_prompt_encoder.get_dense_pe().to(torch.bfloat16),
+            sparse_prompt_embeddings=sparse_embeddings.to(torch.bfloat16),
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=multimask_output,
             repeat_image=False,  # the image is already batched
@@ -469,7 +469,8 @@ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs)

     def forward_image(self, img_batch: torch.Tensor):
         """Get the image feature on the input batch."""
-        backbone_out = self.image_encoder(img_batch)
+        backbone_out = self.image_encoder(img_batch.clone())
+        backbone_out = {k: [c.clone() for c in backbone_out[k]] if type(backbone_out[k]) == list else backbone_out[k].clone() for k in backbone_out}
         if self.use_high_res_features_in_sam:
             # precompute projected level 0 and level 1 features in SAM decoder
             # to avoid running it again on every SAM click
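
Note: forward_image now clones both the input batch and every tensor in the encoder output. A plausible reading is that this decouples the returned tensors from buffers owned by the compiled encoder (for example CUDA-graph static outputs). The cloning pattern itself, as a standalone sketch with an illustrative helper name:

import torch

def clone_encoder_output(backbone_out: dict) -> dict:
    # Copy every tensor (or list of tensors) in the encoder output so
    # downstream code never aliases the encoder's internal buffers.
    return {
        k: [t.clone() for t in v] if isinstance(v, list) else v.clone()
        for k, v in backbone_out.items()
    }

example = {"vision_features": torch.randn(1, 4), "vision_pos_enc": [torch.randn(1, 4)]}
copied = clone_encoder_output(example)
print(copied["vision_features"].data_ptr() != example["vision_features"].data_ptr())  # True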
@@ -638,7 +639,7 @@ def _prepare_memory_conditioned_features(
                     .to(device=device, non_blocking=True)
                 )
                 obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
-                obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+                obj_pos = self.obj_ptr_tpos_proj(obj_pos.to(torch.bfloat16))
                 obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
             else:
                 obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
@@ -670,18 +671,20 @@ def _prepare_memory_conditioned_features(
         memory = torch.cat(to_cat_memory, dim=0)
         memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

-        current_vision_feats = [c.clone() for c in current_vision_feats]
-        current_vision_pos_embeds = [c.clone() for c in current_vision_pos_embeds]
-        memory = memory.clone()
-        memory_pos_embed = memory_pos_embed.clone()
-        pix_feat_with_mem = self.memory_attention(
-            curr=current_vision_feats,
-            curr_pos=current_vision_pos_embeds,
-            memory=memory,
-            memory_pos=memory_pos_embed,
-            num_obj_ptr_tokens=num_obj_ptr_tokens,
-        )
-        pix_feat_with_mem = pix_feat_with_mem.clone()
+        with torch.autograd.profiler.record_function("self.memory_attention"):
+            current_vision_feats = [c.clone().to(torch.bfloat16) for c in current_vision_feats]
+            current_vision_pos_embeds = [c.clone().to(torch.bfloat16) for c in current_vision_pos_embeds]
+            memory = memory.clone().to(torch.bfloat16)
+            memory_pos_embed = memory_pos_embed.clone().to(torch.bfloat16)
+            pix_feat_with_mem = self.memory_attention(
+                curr=current_vision_feats,
+                curr_pos=current_vision_pos_embeds,
+                memory=memory,
+                memory_pos=memory_pos_embed,
+                num_obj_ptr_tokens=num_obj_ptr_tokens,
+            )
+            pix_feat_with_mem = pix_feat_with_mem.clone()

         # reshape the output (HW)BC => BCHW
         pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
         return pix_feat_with_mem
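
Note: the record_function wrappers added here (and around memory_encoder and _forward_sam_heads below) only label regions for the profiler; they do not change numerics. A minimal, standalone sketch (toy attend function, not SAM2 code) of how such a label shows up in a trace:

import torch

def attend(q: torch.Tensor) -> torch.Tensor:
    # Same annotation pattern as the diff: name the region so profiler
    # tables and traces attribute its time to "memory_attention" explicitly.
    with torch.autograd.profiler.record_function("memory_attention"):
        return torch.softmax(q @ q.transpose(-1, -2), dim=-1)

with torch.profiler.profile() as prof:
    attend(torch.randn(8, 16))
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))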
@@ -719,11 +722,14 @@ def _encode_new_memory(
             mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
         if self.sigmoid_bias_for_mem_enc != 0.0:
             mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
-        maskmem_out = self.memory_encoder(
-            pix_feat,
-            mask_for_mem,
-            skip_mask_sigmoid=True,  # sigmoid already applied
-        )
+        with torch.autograd.profiler.record_function("self.memory_encoder"):
+            # pix_feat = pix_feat.clone()
+            # mask_for_mem = mask_for_mem.clone()
+            maskmem_out = self.memory_encoder(
+                pix_feat.to(torch.bfloat16),
+                mask_for_mem.to(torch.bfloat16),
+                skip_mask_sigmoid=True,  # sigmoid already applied
+            )
         maskmem_features = maskmem_out["vision_features"].clone()
         maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
         # add a no-object embedding to the spatial memory to indicate that the frame
@@ -794,13 +800,14 @@ def _track_step(
             assert multimask_output
         if point_inputs is not None:
             point_inputs = {k: point_inputs[k].contiguous() for k in point_inputs}
-        sam_outputs = self._forward_sam_heads(
-            backbone_features=pix_feat.contiguous(),
-            point_inputs=point_inputs,
-            mask_inputs=mask_inputs,
-            high_res_features=[h.contiguous() for h in high_res_features],
-            multimask_output=multimask_output,
-        )
+        with torch.autograd.profiler.record_function("self._forward_sam_heads"):
+            sam_outputs = self._forward_sam_heads(
+                backbone_features=pix_feat.contiguous().to(torch.bfloat16),
+                point_inputs=point_inputs,
+                mask_inputs=mask_inputs,
+                high_res_features=[h.contiguous() for h in high_res_features],
+                multimask_output=multimask_output,
+            )

         return current_out, sam_outputs, high_res_features, pix_feat

@@ -854,7 +861,7 @@ def track_step(
         current_out, sam_outputs, _, _ = self._track_step(
             frame_idx,
             is_init_cond_frame,
-            current_vision_feats,
+            [c.clone() for c in current_vision_feats],
             current_vision_pos_embeds,
             feat_sizes,
             point_inputs,
@@ -886,7 +893,7 @@ def track_step(
         # Finally run the memory encoder on the predicted mask to encode
         # it into a new memory feature (that can be used in future frames)
         self._encode_memory_in_output(
-            current_vision_feats,
+            [c.clone() for c in current_vision_feats],
             feat_sizes,
             point_inputs,
             run_mem_encoder,
torchao/_models/sam2/sam2_video_predictor.py (2 changes: 1 addition & 1 deletion)

@@ -904,7 +904,7 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
             device = inference_state["device"]
-            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
+            image = inference_state["images"][frame_idx].to(device).bfloat16().unsqueeze(0)
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
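
Note: Tensor.bfloat16() used here is shorthand for .to(torch.bfloat16), so the cached frame now enters forward_image already in the encoder's dtype. A quick standalone check of the equivalence (the frame tensor below is a stand-in, not the real cached image):

import torch

frame = torch.rand(3, 64, 64)                # stand-in for a cached video frame
a = frame.bfloat16()
b = frame.to(torch.bfloat16)
print(a.dtype == b.dtype == torch.bfloat16)  # True
print(torch.equal(a, b))                     # True: same rounding, same values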