1 file changed: +8 −6 lines

@@ -240,17 +240,19 @@ def set_lora(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                         1, 0)
-        embeddings_indices = torch.narrow(
-            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
 
-        indices = embeddings_indices[1]
+        # NB: Don't use torch.narrow here. torch.narrow triggers some
+        # Dynamic Shape specialization in torch.compile
+        num_tokens = x.shape[0]
+        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
+        indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
+
         full_lora_a_embeddings = F.embedding(
-            x + indices,
+            x + indices_1,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0]
         full_output = self.base_layer.forward(x +
-                                              (indices * added_tokens_mask))
+                                              (indices_0 * added_tokens_mask))
 
         full_output_org = full_output
         if full_output.ndim == 3: