@@ -7,7 +7,6 @@ use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
77use candle_nn:: { Embedding , VarBuilder } ;
88use serde:: Deserialize ;
99use text_embeddings_backend_core:: { Batch , ModelType , Pool } ;
10- use std:: str:: FromStr ;
1110
1211// https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/configuration_modernbert.py
1312#[ derive( Debug , Clone , PartialEq , Deserialize ) ]
@@ -38,7 +37,7 @@ pub struct ModernBertConfig {
3837 pub mlp_bias : Option < bool > ,
3938 pub mlp_dropout : Option < f64 > ,
4039 pub decoder_bias : Option < bool > ,
41- pub classifier_pooling : Option < String > ,
40+ pub classifier_pooling : Option < Pool > ,
4241 pub classifier_dropout : Option < f64 > ,
4342 pub classifier_bias : Option < bool > ,
4443 pub classifier_activation : HiddenAct ,
@@ -485,11 +484,7 @@ impl ModernBertModel {
485484 pub fn load ( vb : VarBuilder , config : & ModernBertConfig , model_type : ModelType ) -> Result < Self > {
486485 let ( pool, classifier) = match model_type {
487486 ModelType :: Classifier => {
488- let pool: Pool = config
489- . classifier_pooling
490- . as_deref ( )
491- . and_then ( |s| Pool :: from_str ( s) . ok ( ) )
492- . unwrap_or ( Pool :: Cls ) ;
487+ let pool: Pool = config. classifier_pooling . clone ( ) . unwrap_or ( Pool :: Cls ) ;
493488
494489 let classifier: Box < dyn ClassificationHead + Send > =
495490 Box :: new ( ModernBertClassificationHead :: load ( vb. clone ( ) , config) ?) ;
0 commit comments