@@ -393,6 +393,15 @@ def get_input_positions_tensor(
                 context_len=context_len,
                 seq_len=seq_len,
             )
+        elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
+            return cls._ernie_get_input_positions_tensor(
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                context_len=context_len,
+                seq_len=seq_len,
+            )
         else:
             return cls._vl_get_input_positions_tensor(
                 input_tokens=input_tokens,
@@ -513,6 +522,136 @@ def _glm4v_get_input_positions_tensor(
                                 len(input_tokens)).item()
         return llm_positions, mrope_position_delta
 
+    @classmethod
+    def _ernie_get_input_positions_tensor(
+        cls,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value for Ernie VL."""
+
+        image_token_id = hf_config.im_patch_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_conv_size = hf_config.spatial_conv_size
+        temporal_conv_size = hf_config.temporal_conv_size
+        llm_pos_ids_list: list = []
+
+        if not (image_grid_thw is None and video_grid_thw is None):
+            if isinstance(image_grid_thw, torch.Tensor):
+                image_grid_thw = image_grid_thw.tolist()
+
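+            # Classify every prompt token by modality. Ernie VL reuses the
+            # image placeholder id for video patches, so placeholders seen
+            # between the video start/end tokens are counted as video.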
+            input_token_type: list[str] = []
+            video_check_flg = False
+            for token in input_tokens:
+                if token == video_start_token_id:
+                    video_check_flg = True
+                elif token == video_end_token_id:
+                    video_check_flg = False
+
+                if (token == image_token_id) and (video_check_flg is False):
+                    input_token_type.append("image")
+                elif (token == image_token_id) and (video_check_flg is True):
+                    input_token_type.append("video")
+                else:
+                    input_token_type.append("text")
+
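+            # Collapse consecutive tokens of the same modality into
+            # (modality, start_index, end_index) runs.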
+            input_type_group: list[tuple[str, int, int]] = []
+            for key, group_iter in itertools.groupby(
+                    enumerate(input_token_type), lambda x: x[1]):
+                group_list = list(group_iter)
+                start_index = group_list[0][0]
+                end_index = group_list[-1][0] + 1
+                input_type_group.append((key, start_index, end_index))
+
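+            # Each run resumes numbering right after the largest position
+            # assigned so far (st_idx).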
+            video_frame_num = 1
+            mm_data_idx = 0
+            for modality_type, start_idx, end_idx in input_type_group:
+                st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                    llm_pos_ids_list) > 0 else 0
+                if modality_type == "image":
+                    t, h, w = (
+                        image_grid_thw[mm_data_idx][0],
+                        image_grid_thw[mm_data_idx][1],
+                        image_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = \
+                        t, h // spatial_conv_size, w // spatial_conv_size
+
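+                    # 3D position ids for every image patch, flattened in
+                    # t-major (t, h, w) order.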
+                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+                        -1, llm_grid_h * llm_grid_w).flatten()
+                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                        llm_grid_t, -1, llm_grid_w).flatten()
+                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                        llm_grid_t, llm_grid_h, -1).flatten()
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + st_idx)
+                    mm_data_idx += 1
+
+                elif modality_type == "video":
+                    t, h, w = (
+                        video_grid_thw[mm_data_idx][0],
+                        video_grid_thw[mm_data_idx][1],
+                        video_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t // temporal_conv_size,
+                        h // spatial_conv_size,
+                        w // spatial_conv_size,
+                    )
+
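+                    # One (t, h, w) block per frame left after the temporal
+                    # convolution, all offset by the shared st_idx.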
+                    for t_idx in range(llm_grid_t):
+                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
+                            -1, llm_grid_h * llm_grid_w).flatten()
+                        h_index = torch.arange(llm_grid_h).view(
+                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
+                        w_index = torch.arange(llm_grid_w).view(
+                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx)
+
+                    mm_data_idx += 1
+                    video_frame_num += 1
+
+                else:
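+                    # Text runs get ordinary sequential positions, identical
+                    # across all three mrope axes.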
+                    text_len = end_idx - start_idx
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) +
+                        st_idx)
+                    video_frame_num = 1
+
+        else:
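+            # No multimodal items: plain sequential positions for the
+            # whole token sequence.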
+            text_len = len(input_tokens)
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1))
+
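+        # The delta records how far the 3D positions run past (or short of)
+        # plain 1D indexing; decode-time positions are offset by it.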
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:seq_len]
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def _vl_get_input_positions_tensor(
         cls,