@@ -125,6 +125,27 @@ static void replace_all(std::string & s, const std::string & search, const std::
125125 }
126126 s = std::move (result);
127127}
128+
129+ static bool is_float_close (float a, float b, float abs_tol) {
130+ // Check for non-negative tolerance
131+ if (abs_tol < 0.0 ) {
132+ throw std::invalid_argument (" Tolerance must be non-negative" );
133+ }
134+
135+ // Exact equality check
136+ if (a == b) {
137+ return true ;
138+ }
139+
140+ // Check for infinities
141+ if (std::isinf (a) || std::isinf (b)) {
142+ return false ;
143+ }
144+
145+ // Regular comparison using the provided absolute tolerance
146+ return std::fabs (b - a) <= abs_tol;
147+ }
148+
128149#ifdef GGML_USE_CPU_HBM
129150#include < hbwmalloc.h>
130151#endif
@@ -969,7 +990,24 @@ struct llama_hparams {
969990 float rope_freq_scale_train;
970991
971992 bool operator !=(const llama_hparams & other) const {
972- return static_cast <bool >(memcmp (this , &other, sizeof (llama_hparams))); // NOLINT
993+ if (this ->vocab_only != other.vocab_only ) return true ;
994+ if (this ->n_vocab != other.n_vocab ) return true ;
995+ if (this ->n_ctx_train != other.n_ctx_train ) return true ;
996+ if (this ->n_embd != other.n_embd ) return true ;
997+ if (this ->n_head != other.n_head ) return true ;
998+ if (this ->n_head_kv != other.n_head_kv ) return true ;
999+ if (this ->n_layer != other.n_layer ) return true ;
1000+ if (this ->n_rot != other.n_rot ) return true ;
1001+ if (this ->n_ff != other.n_ff ) return true ;
1002+
1003+ const float EPSILON = 1e-9 ;
1004+
1005+ if (!is_float_close (this ->f_norm_eps , other.f_norm_eps , EPSILON)) return true ;
1006+ if (!is_float_close (this ->f_norm_rms_eps , other.f_norm_rms_eps , EPSILON)) return true ;
1007+ if (!is_float_close (this ->rope_freq_base_train , other.rope_freq_base_train , EPSILON)) return true ;
1008+ if (!is_float_close (this ->rope_freq_scale_train , other.rope_freq_scale_train , EPSILON)) return true ;
1009+
1010+ return false ;
9731011 }
9741012
9751013 uint32_t n_gqa () const {
0 commit comments