@@ -322,19 +322,26 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
322322 let ( n, _) = x. shape ( ) ;
323323 let mut y_hat: Vec < TX > = Array1 :: zeros ( n) ;
324324
325+ let mut row = Vec :: with_capacity ( n) ;
325326 for i in 0 ..n {
326- let row_pred: TX =
327- self . predict_for_row ( Vec :: from_iterator ( x. get_row ( i) . iterator ( 0 ) . copied ( ) , n) ) ;
327+ row. clear ( ) ;
328+ row. extend ( x. get_row ( i) . iterator ( 0 ) . copied ( ) ) ;
329+ let row_pred: TX = self . predict_for_row ( & row) ;
328330 y_hat. set ( i, row_pred) ;
329331 }
330332
331333 Ok ( y_hat)
332334 }
333335
334- fn predict_for_row ( & self , x : Vec < TX > ) -> TX {
336+ fn predict_for_row ( & self , x : & [ TX ] ) -> TX {
335337 let mut f = self . b . unwrap ( ) ;
336338
339+ let xi: Vec < _ > = x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
337340 for i in 0 ..self . instances . as_ref ( ) . unwrap ( ) . len ( ) {
341+ let xj: Vec < _ > = self . instances . as_ref ( ) . unwrap ( ) [ i]
342+ . iter ( )
343+ . map ( |e| e. to_f64 ( ) . unwrap ( ) )
344+ . collect ( ) ;
338345 f += self . w . as_ref ( ) . unwrap ( ) [ i]
339346 * TX :: from (
340347 self . parameters
@@ -343,13 +350,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
343350 . kernel
344351 . as_ref ( )
345352 . unwrap ( )
346- . apply (
347- & x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
348- & self . instances . as_ref ( ) . unwrap ( ) [ i]
349- . iter ( )
350- . map ( |e| e. to_f64 ( ) . unwrap ( ) )
351- . collect ( ) ,
352- )
353+ . apply ( & xi, & xj)
353354 . unwrap ( ) ,
354355 )
355356 . unwrap ( ) ;
@@ -472,14 +473,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
472473 let tol = self . parameters . tol ;
473474 let good_enough = TX :: from_i32 ( 1000 ) . unwrap ( ) ;
474475
476+ let mut x = Vec :: with_capacity ( n) ;
475477 for _ in 0 ..self . parameters . epoch {
476478 for i in self . permutate ( n) {
477- self . process (
478- i,
479- Vec :: from_iterator ( self . x . get_row ( i) . iterator ( 0 ) . copied ( ) , n) ,
480- * self . y . get ( i) ,
481- & mut cache,
482- ) ;
479+ x. clear ( ) ;
480+ x. extend ( self . x . get_row ( i) . iterator ( 0 ) . take ( n) . copied ( ) ) ;
481+ self . process ( i, & x, * self . y . get ( i) , & mut cache) ;
483482 loop {
484483 self . reprocess ( tol, & mut cache) ;
485484 self . find_min_max_gradient ( ) ;
@@ -511,24 +510,17 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
511510 let mut cp = 0 ;
512511 let mut cn = 0 ;
513512
513+ let mut x = Vec :: with_capacity ( n) ;
514514 for i in self . permutate ( n) {
515+ x. clear ( ) ;
516+ x. extend ( self . x . get_row ( i) . iterator ( 0 ) . take ( n) . copied ( ) ) ;
515517 if * self . y . get ( i) == TY :: one ( ) && cp < few {
516- if self . process (
517- i,
518- Vec :: from_iterator ( self . x . get_row ( i) . iterator ( 0 ) . copied ( ) , n) ,
519- * self . y . get ( i) ,
520- cache,
521- ) {
518+ if self . process ( i, & x, * self . y . get ( i) , cache) {
522519 cp += 1 ;
523520 }
524521 } else if * self . y . get ( i) == TY :: from ( -1 ) . unwrap ( )
525522 && cn < few
526- && self . process (
527- i,
528- Vec :: from_iterator ( self . x . get_row ( i) . iterator ( 0 ) . copied ( ) , n) ,
529- * self . y . get ( i) ,
530- cache,
531- )
523+ && self . process ( i, & x, * self . y . get ( i) , cache)
532524 {
533525 cn += 1 ;
534526 }
@@ -539,7 +531,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
539531 }
540532 }
541533
542- fn process ( & mut self , i : usize , x : Vec < TX > , y : TY , cache : & mut Cache < TX , TY , X , Y > ) -> bool {
534+ fn process ( & mut self , i : usize , x : & [ TX ] , y : TY , cache : & mut Cache < TX , TY , X , Y > ) -> bool {
543535 for j in 0 ..self . sv . len ( ) {
544536 if self . sv [ j] . index == i {
545537 return true ;
@@ -551,15 +543,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
551543 let mut cache_values: Vec < ( ( usize , usize ) , TX ) > = Vec :: new ( ) ;
552544
553545 for v in self . sv . iter ( ) {
546+ let xi: Vec < _ > = v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
547+ let xj: Vec < _ > = x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
554548 let k = self
555549 . parameters
556550 . kernel
557551 . as_ref ( )
558552 . unwrap ( )
559- . apply (
560- & v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
561- & x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
562- )
553+ . apply ( & xi, & xj)
563554 . unwrap ( ) ;
564555 cache_values. push ( ( ( i, v. index ) , TX :: from ( k) . unwrap ( ) ) ) ;
565556 g -= v. alpha * k;
@@ -578,7 +569,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
578569 cache. insert ( v. 0 , v. 1 . to_f64 ( ) . unwrap ( ) ) ;
579570 }
580571
581- let x_f64 = x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
572+ let x_f64: Vec < _ > = x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
582573 let k_v = self
583574 . parameters
584575 . kernel
@@ -701,8 +692,10 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
701692 let km = sv1. k ;
702693 let gm = sv1. grad ;
703694 let mut best = 0f64 ;
695+ let xi: Vec < _ > = sv1. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
704696 for i in 0 ..self . sv . len ( ) {
705697 let v = & self . sv [ i] ;
698+ let xj: Vec < _ > = v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
706699 let z = v. grad - gm;
707700 let k = cache. get (
708701 sv1,
@@ -711,10 +704,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
711704 . kernel
712705 . as_ref ( )
713706 . unwrap ( )
714- . apply (
715- & sv1. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
716- & v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
717- )
707+ . apply ( & xi, & xj)
718708 . unwrap ( ) ,
719709 ) ;
720710 let mut curv = km + v. k - 2f64 * k;
@@ -732,6 +722,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
732722 }
733723 }
734724
725+ let xi: Vec < _ > = self . sv [ idx_1]
726+ . x
727+ . iter ( )
728+ . map ( |e| e. to_f64 ( ) . unwrap ( ) )
729+ . collect :: < Vec < _ > > ( ) ;
730+
735731 idx_2. map ( |idx_2| {
736732 (
737733 idx_1,
@@ -742,16 +738,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
742738 . as_ref ( )
743739 . unwrap ( )
744740 . apply (
745- & self . sv [ idx_1]
746- . x
747- . iter ( )
748- . map ( |e| e. to_f64 ( ) . unwrap ( ) )
749- . collect ( ) ,
741+ & xi,
750742 & self . sv [ idx_2]
751743 . x
752744 . iter ( )
753745 . map ( |e| e. to_f64 ( ) . unwrap ( ) )
754- . collect ( ) ,
746+ . collect :: < Vec < _ > > ( ) ,
755747 )
756748 . unwrap ( )
757749 } ) ,
@@ -765,8 +757,11 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
765757 let km = sv2. k ;
766758 let gm = sv2. grad ;
767759 let mut best = 0f64 ;
760+
761+ let xi: Vec < _ > = sv2. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
768762 for i in 0 ..self . sv . len ( ) {
769763 let v = & self . sv [ i] ;
764+ let xj: Vec < _ > = v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
770765 let z = gm - v. grad ;
771766 let k = cache. get (
772767 sv2,
@@ -775,10 +770,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
775770 . kernel
776771 . as_ref ( )
777772 . unwrap ( )
778- . apply (
779- & sv2. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
780- & v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
781- )
773+ . apply ( & xi, & xj)
782774 . unwrap ( ) ,
783775 ) ;
784776 let mut curv = km + v. k - 2f64 * k;
@@ -797,6 +789,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
797789 }
798790 }
799791
792+ let xj: Vec < _ > = self . sv [ idx_2]
793+ . x
794+ . iter ( )
795+ . map ( |e| e. to_f64 ( ) . unwrap ( ) )
796+ . collect ( ) ;
797+
800798 idx_1. map ( |idx_1| {
801799 (
802800 idx_1,
@@ -811,12 +809,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
811809 . x
812810 . iter ( )
813811 . map ( |e| e. to_f64 ( ) . unwrap ( ) )
814- . collect ( ) ,
815- & self . sv [ idx_2]
816- . x
817- . iter ( )
818- . map ( |e| e. to_f64 ( ) . unwrap ( ) )
819- . collect ( ) ,
812+ . collect :: < Vec < _ > > ( ) ,
813+ & xj,
820814 )
821815 . unwrap ( )
822816 } ) ,
@@ -835,12 +829,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
835829 . x
836830 . iter ( )
837831 . map ( |e| e. to_f64 ( ) . unwrap ( ) )
838- . collect ( ) ,
832+ . collect :: < Vec < _ > > ( ) ,
839833 & self . sv [ idx_2]
840834 . x
841835 . iter ( )
842836 . map ( |e| e. to_f64 ( ) . unwrap ( ) )
843- . collect ( ) ,
837+ . collect :: < Vec < _ > > ( ) ,
844838 )
845839 . unwrap ( ) ,
846840 ) ) ,
@@ -895,18 +889,18 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
895889 self . sv [ v1] . alpha -= step. to_f64 ( ) . unwrap ( ) ;
896890 self . sv [ v2] . alpha += step. to_f64 ( ) . unwrap ( ) ;
897891
892+ let xi_v1: Vec < _ > = self . sv [ v1] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
893+ let xi_v2: Vec < _ > = self . sv [ v2] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
898894 for i in 0 ..self . sv . len ( ) {
895+ let xj: Vec < _ > = self . sv [ i] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
899896 let k2 = cache. get (
900897 & self . sv [ v2] ,
901898 & self . sv [ i] ,
902899 self . parameters
903900 . kernel
904901 . as_ref ( )
905902 . unwrap ( )
906- . apply (
907- & self . sv [ v2] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
908- & self . sv [ i] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
909- )
903+ . apply ( & xi_v2, & xj)
910904 . unwrap ( ) ,
911905 ) ;
912906 let k1 = cache. get (
@@ -916,10 +910,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
916910 . kernel
917911 . as_ref ( )
918912 . unwrap ( )
919- . apply (
920- & self . sv [ v1] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
921- & self . sv [ i] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
922- )
913+ . apply ( & xi_v1, & xj)
923914 . unwrap ( ) ,
924915 ) ;
925916 self . sv [ i] . grad -= step. to_f64 ( ) . unwrap ( ) * ( k2 - k1) ;
0 commit comments