@@ -199,15 +199,43 @@ def spline1d(x, knots, degree=3):
199199 return X
200200
201201
202- def _make_A_cartesian (x , y , n_knots = 10 , radius = 3.0 , knot_spacing_type = "sqrt" , degree = 3 ):
203- if knot_spacing_type == "sqrt" :
204- knots = np .linspace (- np .sqrt (radius ), np .sqrt (radius ), n_knots )
205- knots = np .sign (knots ) * knots ** 2
202+ def _make_A_cartesian (x , y , n_knots = 10 , radius = 3.0 , spacing = "sqrt" , degree = 3 ):
203+ # Must be odd
204+ n_odd_knots = n_knots if n_knots % 2 == 1 else n_knots + 1
205+ if spacing == "sqrt" :
206+ x_knots = np .linspace (- np .sqrt (radius ), np .sqrt (radius ), n_odd_knots )
207+ x_knots = np .sign (x_knots ) * x_knots ** 2
208+ y_knots = np .linspace (- np .sqrt (radius ), np .sqrt (radius ), n_odd_knots )
209+ y_knots = np .sign (y_knots ) * y_knots ** 2
206210 else :
207- knots = np .linspace (- radius , radius , n_knots )
208- x_spline = spline1d (x , knots = knots , degree = degree )
209- y_spline = spline1d (y , knots = knots , degree = degree )
210-
211+ x_knots = np .linspace (- radius , radius , n_odd_knots )
212+ y_knots = np .linspace (- radius , radius , n_odd_knots )
213+ x_spline = sparse .csr_matrix (
214+ np .asarray (
215+ dmatrix (
216+ "bs(x, knots=knots, degree=degree, include_intercept=True)" ,
217+ {
218+ "x" : list (np .hstack ([x_knots .min (), x , x_knots .max ()])),
219+ "degree" : degree ,
220+ "knots" : x_knots ,
221+ },
222+ )
223+ )[1 :- 1 ]
224+ )
225+ y_spline = sparse .csr_matrix (
226+ np .asarray (
227+ dmatrix (
228+ "bs(x, knots=knots, degree=degree, include_intercept=True)" ,
229+ {
230+ "x" : list (np .hstack ([y_knots .min (), y , y_knots .max ()])),
231+ "degree" : degree ,
232+ "knots" : y_knots ,
233+ },
234+ )
235+ )[1 :- 1 ]
236+ )
237+ x_spline = x_spline [:, np .asarray (x_spline .sum (axis = 0 ))[0 ] != 0 ]
238+ y_spline = y_spline [:, np .asarray (y_spline .sum (axis = 0 ))[0 ] != 0 ]
211239 X = sparse .hstack (
212240 [x_spline .multiply (y_spline [:, idx ]) for idx in range (y_spline .shape [1 ])],
213241 format = "csr" ,
0 commit comments