Skip to content

Commit 3e9c499

Browse files
Update spline to be slightly more robust
Extracted out of #54, this is just a slightly more robust version of this part of the code.
1 parent 846f17a commit 3e9c499

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

src/psfmachine/utils.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,43 @@ def spline1d(x, knots, degree=3):
199199
return X
200200

201201

202-
def _make_A_cartesian(x, y, n_knots=10, radius=3.0, knot_spacing_type="sqrt", degree=3):
203-
if knot_spacing_type == "sqrt":
204-
knots = np.linspace(-np.sqrt(radius), np.sqrt(radius), n_knots)
205-
knots = np.sign(knots) * knots ** 2
202+
def _make_A_cartesian(x, y, n_knots=10, radius=3.0, spacing="sqrt", degree=3):
203+
# Must be odd
204+
n_odd_knots = n_knots if n_knots % 2 == 1 else n_knots + 1
205+
if spacing == "sqrt":
206+
x_knots = np.linspace(-np.sqrt(radius), np.sqrt(radius), n_odd_knots)
207+
x_knots = np.sign(x_knots) * x_knots ** 2
208+
y_knots = np.linspace(-np.sqrt(radius), np.sqrt(radius), n_odd_knots)
209+
y_knots = np.sign(y_knots) * y_knots ** 2
206210
else:
207-
knots = np.linspace(-radius, radius, n_knots)
208-
x_spline = spline1d(x, knots=knots, degree=degree)
209-
y_spline = spline1d(y, knots=knots, degree=degree)
210-
211+
x_knots = np.linspace(-radius, radius, n_odd_knots)
212+
y_knots = np.linspace(-radius, radius, n_odd_knots)
213+
x_spline = sparse.csr_matrix(
214+
np.asarray(
215+
dmatrix(
216+
"bs(x, knots=knots, degree=degree, include_intercept=True)",
217+
{
218+
"x": list(np.hstack([x_knots.min(), x, x_knots.max()])),
219+
"degree": degree,
220+
"knots": x_knots,
221+
},
222+
)
223+
)[1:-1]
224+
)
225+
y_spline = sparse.csr_matrix(
226+
np.asarray(
227+
dmatrix(
228+
"bs(x, knots=knots, degree=degree, include_intercept=True)",
229+
{
230+
"x": list(np.hstack([y_knots.min(), y, y_knots.max()])),
231+
"degree": degree,
232+
"knots": y_knots,
233+
},
234+
)
235+
)[1:-1]
236+
)
237+
x_spline = x_spline[:, np.asarray(x_spline.sum(axis=0))[0] != 0]
238+
y_spline = y_spline[:, np.asarray(y_spline.sum(axis=0))[0] != 0]
211239
X = sparse.hstack(
212240
[x_spline.multiply(y_spline[:, idx]) for idx in range(y_spline.shape[1])],
213241
format="csr",

0 commit comments

Comments
 (0)