
no autocast
cpuhrsch committed Feb 21, 2025
1 parent eb23791 commit 589811f
Showing 4 changed files with 16 additions and 14 deletions.
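
The commit replaces a process-wide torch.autocast context with explicit bfloat16 conversion: the predictor's modules are cast once via .to(torch.bfloat16), and tensors produced outside those modules are cast at each call site. Below is a minimal sketch of the two patterns, not taken from the SAM2 code itself; TinyNet, model, and x are hypothetical stand-ins, and a CUDA GPU is assumed.

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

model = TinyNet().cuda()
x = torch.randn(4, 16, device="cuda")

# Pattern being removed: autocast keeps fp32 weights and downcasts ops on the fly.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y_autocast = model(x)

# Pattern this commit moves to: cast the parameters once and make every
# input tensor match that dtype explicitly at the call site.
model = model.to(torch.bfloat16)
y_explicit = model(x.to(torch.bfloat16))

Casting the whole module avoids autocast's per-op dispatch overhead, but any fp32 tensor that reaches a bfloat16 submodule will then typically raise a dtype mismatch error instead of being converted silently, which is why the remaining hunks add .to(torch.bfloat16) at individual call sites.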
6 changes: 4 additions & 2 deletions examples/sam2_vos_example/video_profile.py
@@ -270,7 +270,7 @@ def main(
 )
 
 # use bfloat16 for the entire notebook
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+# torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 
@@ -291,7 +291,9 @@ def main(
 # hydra_overrides_extra=hydra_overrides_extra,
 )
 predictor._frame_batch_size = frame_batch_size
-predictor.image_encoder.trunk = predictor.image_encoder.trunk.to(torch.bfloat16)
+# predictor.image_encoder = predictor.image_encoder.to(torch.bfloat16)
+predictor = predictor.to(torch.bfloat16)
+predictor.sam_mask_decoder._src_dtype = torch.bfloat16
 from torchao._models.sam2.modeling.sam.transformer import RoPEAttention
 
 rope_attention_modules = [
2 changes: 1 addition & 1 deletion torchao/_models/sam2/modeling/position_encoding.py
@@ -129,7 +129,7 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
 """Positionally encode points that are normalized to [0,1]."""
 # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
 coords = 2 * coords - 1
-coords = coords @ self.positional_encoding_gaussian_matrix
+coords = coords @ (self.positional_encoding_gaussian_matrix.float())
 coords = 2 * np.pi * coords
 # outputs d_1 x ... x d_n x C shape
 return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
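
Why the .float() above: once the whole predictor is cast to bfloat16, the registered positional_encoding_gaussian_matrix buffer becomes bfloat16 while coords can still arrive as float32, and torch.matmul does not mix floating dtypes. A small standalone sketch of that failure mode and the fix, with hypothetical shapes chosen only for illustration:

import torch

coords = torch.rand(8, 2)                              # float32 points in [0, 1]
gaussian = torch.randn(2, 64, dtype=torch.bfloat16)    # buffer after .to(torch.bfloat16)

try:
    _ = coords @ gaussian                              # float32 @ bfloat16
except RuntimeError as err:
    print("mixed-dtype matmul fails:", err)

out = coords @ gaussian.float()                        # cast the buffer back, as the hunk does
print(out.dtype)                                       # torch.float32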
20 changes: 10 additions & 10 deletions torchao/_models/sam2/modeling/sam2_base.py
@@ -352,8 +352,8 @@ def _forward_sam_heads(
 object_score_logits,
 ) = self.sam_mask_decoder(
 image_embeddings=backbone_features,
-image_pe=self.sam_prompt_encoder.get_dense_pe(),
-sparse_prompt_embeddings=sparse_embeddings,
+image_pe=self.sam_prompt_encoder.get_dense_pe().to(torch.bfloat16),
+sparse_prompt_embeddings=sparse_embeddings.to(torch.bfloat16),
 dense_prompt_embeddings=dense_embeddings,
 multimask_output=multimask_output,
 repeat_image=False, # the image is already batched
@@ -639,7 +639,7 @@ def _prepare_memory_conditioned_features(
 .to(device=device, non_blocking=True)
 )
 obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
-obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+obj_pos = self.obj_ptr_tpos_proj(obj_pos.to(torch.bfloat16))
 obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
 else:
 obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
@@ -672,10 +672,10 @@ def _prepare_memory_conditioned_features(
 memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
 
 with torch.autograd.profiler.record_function("self.memory_attention"):
-current_vision_feats = [c.clone() for c in current_vision_feats]
-current_vision_pos_embeds = [c.clone() for c in current_vision_pos_embeds]
-memory = memory.clone()
-memory_pos_embed = memory_pos_embed.clone()
+current_vision_feats = [c.clone().to(torch.bfloat16) for c in current_vision_feats]
+current_vision_pos_embeds = [c.clone().to(torch.bfloat16) for c in current_vision_pos_embeds]
+memory = memory.clone().to(torch.bfloat16)
+memory_pos_embed = memory_pos_embed.clone().to(torch.bfloat16)
 pix_feat_with_mem = self.memory_attention(
 curr=current_vision_feats,
 curr_pos=current_vision_pos_embeds,
@@ -726,8 +726,8 @@ def _encode_new_memory(
 # pix_feat = pix_feat.clone()
 # mask_for_mem = mask_for_mem.clone()
 maskmem_out = self.memory_encoder(
-pix_feat,
-mask_for_mem,
+pix_feat.to(torch.bfloat16),
+mask_for_mem.to(torch.bfloat16),
 skip_mask_sigmoid=True, # sigmoid already applied
 )
 maskmem_features = maskmem_out["vision_features"].clone()
@@ -802,7 +802,7 @@ def _track_step(
 point_inputs = {k: point_inputs[k].contiguous() for k in point_inputs}
 with torch.autograd.profiler.record_function("self._forward_sam_heads"):
 sam_outputs = self._forward_sam_heads(
-backbone_features=pix_feat.contiguous(),
+backbone_features=pix_feat.contiguous().to(torch.bfloat16),
 point_inputs=point_inputs,
 mask_inputs=mask_inputs,
 high_res_features=[h.contiguous() for h in high_res_features],
2 changes: 1 addition & 1 deletion torchao/_models/sam2/sam2_video_predictor.py
@@ -904,7 +904,7 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
 if backbone_out is None:
 # Cache miss -- we will run inference on a single image
 device = inference_state["device"]
-image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
+image = inference_state["images"][frame_idx].to(device).bfloat16().unsqueeze(0)
 backbone_out = self.forward_image(image)
 # Cache the most recent frame's feature (for repeated interactions with
 # a frame; we can use an LRU cache for more frames in the future).
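
The video predictor used to convert each cached frame to float32 before running the image encoder; with the encoder now held in bfloat16, the frame is converted directly with .bfloat16(), which is shorthand for .to(torch.bfloat16) and halves per-element activation memory. A tiny illustration with a made-up frame tensor:

import torch

frame = torch.rand(3, 1024, 1024)            # decoded frame, float32 by default

as_bf16 = frame.bfloat16()                   # equivalent to frame.to(torch.bfloat16)
assert as_bf16.dtype == torch.bfloat16
assert torch.equal(as_bf16, frame.to(torch.bfloat16))

print(frame.element_size(), as_bf16.element_size())   # 4 bytes vs 2 bytes per element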
