@@ -115,47 +115,77 @@ def __init__(
115115 self .theta = theta
116116 self ._causal_rope_fix = _causal_rope_fix
117117
118- def forward (
118+ def _prepare_video_coords (
119119 self ,
120- hidden_states : torch . Tensor ,
120+ batch_size : int ,
121121 num_frames : int ,
122122 height : int ,
123123 width : int ,
124- frame_rate : Optional [int ] = None ,
125- rope_interpolation_scale : Optional [Tuple [torch .Tensor , float , float ]] = None ,
126- ) -> Tuple [torch .Tensor , torch .Tensor ]:
127- batch_size = hidden_states .size (0 )
128-
124+ rope_interpolation_scale : Tuple [torch .Tensor , float , float ],
125+ frame_rate : float ,
126+ device : torch .device ,
127+ ) -> torch .Tensor :
129128 # Always compute rope in fp32
130- grid_h = torch .arange (height , dtype = torch .float32 , device = hidden_states . device )
131- grid_w = torch .arange (width , dtype = torch .float32 , device = hidden_states . device )
132- grid_f = torch .arange (num_frames , dtype = torch .float32 , device = hidden_states . device )
129+ grid_h = torch .arange (height , dtype = torch .float32 , device = device )
130+ grid_w = torch .arange (width , dtype = torch .float32 , device = device )
131+ grid_f = torch .arange (num_frames , dtype = torch .float32 , device = device )
133132 grid = torch .meshgrid (grid_f , grid_h , grid_w , indexing = "ij" )
134133 grid = torch .stack (grid , dim = 0 )
135134 grid = grid .unsqueeze (0 ).repeat (batch_size , 1 , 1 , 1 , 1 )
136135
137- if rope_interpolation_scale is not None :
138- if isinstance (rope_interpolation_scale , tuple ):
139- # This will be deprecated in v0.34.0
140- grid [:, 0 :1 ] = grid [:, 0 :1 ] * rope_interpolation_scale [0 ] * self .patch_size_t / self .base_num_frames
141- grid [:, 1 :2 ] = grid [:, 1 :2 ] * rope_interpolation_scale [1 ] * self .patch_size / self .base_height
142- grid [:, 2 :3 ] = grid [:, 2 :3 ] * rope_interpolation_scale [2 ] * self .patch_size / self .base_width
136+ if isinstance (rope_interpolation_scale , tuple ):
137+ # This will be deprecated in v0.34.0
138+ grid [:, 0 :1 ] = grid [:, 0 :1 ] * rope_interpolation_scale [0 ] * self .patch_size_t / self .base_num_frames
139+ grid [:, 1 :2 ] = grid [:, 1 :2 ] * rope_interpolation_scale [1 ] * self .patch_size / self .base_height
140+ grid [:, 2 :3 ] = grid [:, 2 :3 ] * rope_interpolation_scale [2 ] * self .patch_size / self .base_width
141+ else :
142+ if not self ._causal_rope_fix :
143+ grid [:, 0 :1 ] = grid [:, 0 :1 ] * rope_interpolation_scale [0 :1 ] * self .patch_size_t / self .base_num_frames
143144 else :
144- if not self ._causal_rope_fix :
145- grid [:, 0 :1 ] = (
146- grid [:, 0 :1 ] * rope_interpolation_scale [0 :1 ] * self .patch_size_t / self .base_num_frames
147- )
148- else :
149- grid [:, 0 :1 ] = (
150- ((grid [:, 0 :1 ] - 1 ) * rope_interpolation_scale [0 :1 ] + 1 / frame_rate ).clamp (min = 0 )
151- * self .patch_size_t
152- / self .base_num_frames
153- )
154- grid [:, 1 :2 ] = grid [:, 1 :2 ] * rope_interpolation_scale [1 :2 ] * self .patch_size / self .base_height
155- grid [:, 2 :3 ] = grid [:, 2 :3 ] * rope_interpolation_scale [2 :3 ] * self .patch_size / self .base_width
145+ grid [:, 0 :1 ] = (
146+ ((grid [:, 0 :1 ] - 1 ) * rope_interpolation_scale [0 :1 ] + 1 / frame_rate ).clamp (min = 0 )
147+ * self .patch_size_t
148+ / self .base_num_frames
149+ )
150+ grid [:, 1 :2 ] = grid [:, 1 :2 ] * rope_interpolation_scale [1 :2 ] * self .patch_size / self .base_height
151+ grid [:, 2 :3 ] = grid [:, 2 :3 ] * rope_interpolation_scale [2 :3 ] * self .patch_size / self .base_width
156152
157153 grid = grid .flatten (2 , 4 ).transpose (1 , 2 )
158154
155+ return grid
156+
157+ def forward (
158+ self ,
159+ hidden_states : torch .Tensor ,
160+ num_frames : Optional [int ] = None ,
161+ height : Optional [int ] = None ,
162+ width : Optional [int ] = None ,
163+ frame_rate : Optional [int ] = None ,
164+ rope_interpolation_scale : Optional [Tuple [torch .Tensor , float , float ]] = None ,
165+ video_coords : Optional [torch .Tensor ] = None ,
166+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
167+ batch_size = hidden_states .size (0 )
168+
169+ if video_coords is None :
170+ grid = self ._prepare_video_coords (
171+ batch_size ,
172+ num_frames ,
173+ height ,
174+ width ,
175+ rope_interpolation_scale = rope_interpolation_scale ,
176+ frame_rate = frame_rate ,
177+ device = hidden_states .device ,
178+ )
179+ else :
180+ grid = torch .stack (
181+ [
182+ video_coords [:, 0 ] / self .base_num_frames ,
183+ video_coords [:, 1 ] / self .base_height ,
184+ video_coords [:, 2 ] / self .base_width ,
185+ ],
186+ dim = - 1 ,
187+ )
188+
159189 start = 1.0
160190 end = self .theta
161191 freqs = self .theta ** torch .linspace (
@@ -387,11 +417,12 @@ def forward(
387417 encoder_hidden_states : torch .Tensor ,
388418 timestep : torch .LongTensor ,
389419 encoder_attention_mask : torch .Tensor ,
390- num_frames : int ,
391- height : int ,
392- width : int ,
393- frame_rate : int ,
420+ num_frames : Optional [ int ] = None ,
421+ height : Optional [ int ] = None ,
422+ width : Optional [ int ] = None ,
423+ frame_rate : Optional [ int ] = None ,
394424 rope_interpolation_scale : Optional [Union [Tuple [float , float , float ], torch .Tensor ]] = None ,
425+ video_coords : Optional [torch .Tensor ] = None ,
395426 attention_kwargs : Optional [Dict [str , Any ]] = None ,
396427 return_dict : bool = True ,
397428 ) -> torch .Tensor :
@@ -414,7 +445,9 @@ def forward(
414445 msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
415446 deprecate ("rope_interpolation_scale" , "0.34.0" , msg )
416447
417- image_rotary_emb = self .rope (hidden_states , num_frames , height , width , frame_rate , rope_interpolation_scale )
448+ image_rotary_emb = self .rope (
449+ hidden_states , num_frames , height , width , frame_rate , rope_interpolation_scale , video_coords
450+ )
418451
419452 # convert encoder_attention_mask to a bias the same way we do for attention_mask
420453 if encoder_attention_mask is not None and encoder_attention_mask .ndim == 2 :
0 commit comments