@@ -603,8 +603,9 @@ def __init__(
 
         rope_theta = getattr(config, "rope_theta", 10000)
 
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
@@ -860,8 +861,9 @@ def layer_fn(prefix):
                               cache_shape=self.cache_shape)
 
         rope_theta = getattr(config, "rope_theta", 10000)
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
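
Context for the change: `getattr`'s default argument is only used when the attribute is missing entirely, so a config that explicitly sets `head_dim = None` would previously pass that None through. The new pattern falls back to the computed value in both cases. A minimal sketch of the difference, using a hypothetical `Cfg` stand-in for the real config object:

```python
class Cfg:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None  # attribute is present, but explicitly unset

cfg = Cfg()

# Old pattern: returns None, because the attribute exists and getattr's
# default is never consulted.
old = getattr(cfg, "head_dim",
              cfg.hidden_size // cfg.num_attention_heads)

# New pattern: falls back to the computed value for both a missing
# attribute and an explicit None.
new = getattr(cfg, "head_dim", None)
if new is None:
    new = cfg.hidden_size // cfg.num_attention_heads

print(old, new)  # None 128
```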