@@ -551,7 +551,7 @@ def __post_init__(self) -> None:
         # For pooling models, self.task is used to indicate the
         # user-selected task
         if self.task == "score":
-            if self.registry.is_cross_encoder_model(self.architectures):
+            if self._is_classify_task(self.architectures):
                 self.task = "classify"
             else:
                 self.task = "embed"
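Net effect of this hunk: `--task score` now resolves to "classify" for any architecture named `*ForSequenceClassification`, not only for models the registry knows as cross-encoders. A minimal sketch of the resolution, showing only the suffix test (the new helper added below also falls back to the registry):

```python
# Illustrative sketch, not vLLM code: how "score" now picks a task.
def resolve_score_task(architectures: list[str]) -> str:
    is_classify = any(
        arch.endswith("ForSequenceClassification") for arch in architectures)
    return "classify" if is_classify else "embed"

assert resolve_score_task(["BertForSequenceClassification"]) == "classify"
assert resolve_score_task(["BertModel"]) == "embed"
```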
@@ -806,21 +806,24 @@ def _verify_tokenizer_mode(self) -> None:
                 f"one of {get_args(TokenizerMode)}.")
         self.tokenizer_mode = tokenizer_mode
 
+    def _is_classify_task(self, architectures: list[str]):
+        for arch in architectures:
+            if arch.endswith("ForSequenceClassification"):
+                return True
+        return self.registry.is_cross_encoder_model(architectures)
+
     def _get_preferred_pooling_task(
         self,
         architectures: list[str],
     ) -> _ResolvedTask:
         model_id = self.model
         if get_pooling_config(model_id, self.revision):
             return "embed"
-        if self.registry.is_cross_encoder_model(architectures):
-            return "classify"
         if self.registry.is_transcription_model(architectures):
             return "transcription"
 
         suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
             # Other models follow this pattern
-            ("ForSequenceClassification", "classify"),
             ("EmbeddingModel", "embed"),
             ("RewardModel", "reward"),
         ]
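The new `_is_classify_task` helper centralizes the check that previously lived in `_get_preferred_pooling_task`: the architecture-name suffix wins first, and only then is the registry consulted. A standalone sketch of that precedence (`DummyRegistry` and its cross-encoder entry are stand-ins, not vLLM classes):

```python
# Sketch of the helper's precedence: suffix match first, registry second.
class DummyRegistry:
    def __init__(self, cross_encoders: set[str]):
        self.cross_encoders = cross_encoders

    def is_cross_encoder_model(self, architectures: list[str]) -> bool:
        return any(arch in self.cross_encoders for arch in architectures)

def is_classify_task(architectures: list[str],
                     registry: DummyRegistry) -> bool:
    if any(a.endswith("ForSequenceClassification") for a in architectures):
        return True  # suffix match short-circuits the registry lookup
    return registry.is_cross_encoder_model(architectures)

registry = DummyRegistry(cross_encoders={"XLMRobertaForCrossEncoding"})
assert is_classify_task(["Qwen3ForSequenceClassification"], registry)
assert is_classify_task(["XLMRobertaForCrossEncoding"], registry)
assert not is_classify_task(["LlamaForCausalLM"], registry)
```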
@@ -878,11 +881,14 @@ def _get_supported_tasks(
         self,
         task_option: TaskOption,
     ) -> dict[RunnerType, list[_ResolvedTask]]:
-        return {
-            "generate": self._get_supported_generation_tasks(task_option),
-            "pooling": self._get_supported_pooling_tasks(task_option),
-            "draft": ["draft"]
-        }
+        if self._is_classify_task(self.architectures):
+            return {"generate": [], "pooling": ["classify"], "draft": []}
+        else:
+            return {
+                "generate": self._get_supported_generation_tasks(task_option),
+                "pooling": self._get_supported_pooling_tasks(task_option),
+                "draft": ["draft"]
+            }
 
     def _get_supported_runner_types(
         self,
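With the early return above, a classify model advertises exactly one task across all runners, which is what lets `_resolve_runner` (next hunk) pick "pooling" unconditionally. A sketch of the expected shape of the collapsed map:

```python
# Dict literal mirrors the early return above; keys as in the diff.
supported_tasks = {"generate": [], "pooling": ["classify"], "draft": []}

all_tasks = [t for tasks in supported_tasks.values() for t in tasks]
assert all_tasks == ["classify"]  # no generate/draft tasks remain
```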
@@ -925,12 +931,16 @@ def _resolve_runner(
                 f"Available tasks for runner={task_runner!r}: "
                 f"{supported_tasks[task_runner]}")
 
+        if "classify" in supported_tasks.get("pooling", []):
+            # When multiple pooling tasks are present, default to
+            # pooling (e.g. cross-encoder) for non-standard architectures.
+            return "pooling"
+
         suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
             ("ForCausalLM", "generate"),
             ("ForConditionalGeneration", "generate"),
             ("ChatModel", "generate"),
             ("LMHeadModel", "generate"),
-            ("ForSequenceClassification", "pooling"),
             ("EmbeddingModel", "pooling"),
             ("RewardModel", "pooling"),
         ]
@@ -940,10 +950,6 @@ def _resolve_runner(
             if arch.endswith(suffix) and pref_runner in supported_runner_types:
                 return pref_runner
 
-        if "classify" in supported_tasks.get("pooling", []):
-            # When multiple pooling tasks are present, default to
-            # pooling (eg cross-encoder) for non-standard architectures.
-            return "pooling"
         if "generate" in supported_runner_types:
             return "generate"
         if "pooling" in supported_runner_types:
@@ -1525,7 +1531,7 @@ def is_v1_compatible(self) -> bool:
 
     @property
     def is_matryoshka(self) -> bool:
-        return (hasattr(self.hf_config, "matryoshka_dimensions")
+        return (bool(getattr(self.hf_config, "matryoshka_dimensions", None))
                 or getattr(self.hf_config, "is_matryoshka", False))
 
     @property
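The `is_matryoshka` change guards against configs where `matryoshka_dimensions` exists but is `None` (or empty): `hasattr` reports `True` for those, while `bool(getattr(..., None))` does not. A quick demonstration with a stand-in config object:

```python
from types import SimpleNamespace

# Stand-in for hf_config; the attribute is present but unset.
cfg = SimpleNamespace(matryoshka_dimensions=None)

assert hasattr(cfg, "matryoshka_dimensions")                  # old check: True
assert not bool(getattr(cfg, "matryoshka_dimensions", None))  # new check: False
```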
@@ -1539,13 +1545,11 @@ def use_pad_token(self) -> bool:
         return getattr(self.hf_config, "use_pad_token", True)
 
     def get_and_verify_max_len(self, max_model_len: int):
-        # For pooling models, the tokenizer's `model_max_length` is often a
-        # reliable source for the maximum sequence length. However, for
-        # generative models, this can be incorrect and unduly limit the
-        # context window (e.g., DeepSeek-R1). Therefore, we only consider
-        # tokenizer_config for pooling models.
+        # Consult the tokenizer config's model_max_length only for pooling
+        # models that use absolute position embeddings.
         tokenizer_config = None
-        if self.runner_type == "pooling":
+        if (self.runner_type == "pooling" and getattr(
+                self.hf_config, "position_embedding_type", "") == "absolute"):
             tokenizer_config = try_get_tokenizer_config(
                 self.tokenizer,
                 trust_remote_code=self.trust_remote_code,
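The rationale, as inferred from the comment: absolute position embeddings give pooling models a hard length ceiling, so the tokenizer config's `model_max_length` is trusted only in that case; rotary or relative schemes should not be capped by it. A sketch of the gate with stand-in configs:

```python
from types import SimpleNamespace

def should_consult_tokenizer_config(runner_type: str, hf_config) -> bool:
    # Mirrors the condition in the diff; hf_config is a stand-in here.
    return (runner_type == "pooling" and getattr(
        hf_config, "position_embedding_type", "") == "absolute")

bert_like = SimpleNamespace(position_embedding_type="absolute")
rotary = SimpleNamespace(position_embedding_type="rotary")
assert should_consult_tokenizer_config("pooling", bert_like)
assert not should_consult_tokenizer_config("pooling", rotary)
assert not should_consult_tokenizer_config("generate", bert_like)
```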