hat_matrix_ not available for BasisSmoother #595

Open
mynanshan opened this issue Dec 28, 2023 · 2 comments

Describe the bug
Setting return_basis=True in BasisSmoother and scoring it with LinearSmootherGeneralizedCVScorer leads to

AttributeError: 'BasisSmoother' object has no attribute 'hat_matrix_'

Looking further into the code, I found that in _LinearSmoother the hat_matrix_ attribute is created inside the fit method by the line

self.hat_matrix_ = self.hat_matrix()

However, this line does not appear in the fit method of BasisSmoother.
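The failure mode described above can be reproduced without scikit-fda. The sketch below uses hypothetical classes (LinearSmootherBase, BrokenSmoother, FixedSmoother and the score helper are illustrative names, not scikit-fda's API): a base class caches hat_matrix_ in fit, a subclass overrides fit without doing so, and anything that later reads the attribute fails with the same AttributeError. The "fix" shown is one possible approach, not the project's actual patch.

```python
class LinearSmootherBase:
    """Hypothetical stand-in for a base class whose fit() caches the hat matrix."""

    def hat_matrix(self):
        # Placeholder: a real smoother would build an (n_points x n_points) matrix.
        return [[1.0]]

    def fit(self, X, y=None):
        # hat_matrix_ only exists after this line has run.
        self.hat_matrix_ = self.hat_matrix()
        return self


class BrokenSmoother(LinearSmootherBase):
    def fit(self, X, y=None):
        # Overrides fit() without setting hat_matrix_ or calling super().fit().
        return self


class FixedSmoother(LinearSmootherBase):
    def fit(self, X, y=None):
        # ... subclass-specific setup would go here ...
        # One possible fix: cache the hat matrix explicitly, as the base class does.
        self.hat_matrix_ = self.hat_matrix()
        return self


def score(estimator):
    # Mirrors the scorer, which reads estimator.hat_matrix_ after fitting.
    return estimator.hat_matrix_


try:
    score(BrokenSmoother().fit(X=None))
except AttributeError as err:
    print(err)  # 'BrokenSmoother' object has no attribute 'hat_matrix_'

print(score(FixedSmoother().fit(X=None)))  # [[1.0]]
```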

To Reproduce
Code to reproduce the behavior:

Apologies for the length of the code that generates the example data:

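import numpy as np
from scipy.integrate import solve_ivp

def simple_linear_ode(t, x, A):
    # Right-hand side of the linear system dx/dt = A x
    return A @ x

t0, t1 = 0, 1
m = 100
t = np.linspace(t0, t1, m + 1)  # m + 1 equally spaced evaluation points

# Block-diagonal A: each 2x2 block rotates with angular frequency 2*pi*(k+1)
n = 10
A = np.zeros((n, n))
for k in range(n // 2):
    A[2*k, 2*k+1] = 2 * np.pi * (k+1)
    A[2*k+1, 2*k] = -2 * np.pi * (k+1)

np.random.seed(0)  # fixed seed so the example is reproducible
x0 = np.random.normal(0, 0.5, n)
ode_solution = solve_ivp(simple_linear_ode, t_span=(t0, t1), y0=x0, t_eval=t, args=(A,))
x = ode_solution.y  # shape (n, m + 1): one trajectory per state variable

# Add Gaussian noise with a per-trajectory signal-to-noise ratio of snr
snr = 5
noise_sd = np.std(x, axis=1) / snr
y = x + np.random.normal(0, noise_sd[:, np.newaxis], x.shape)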
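Since A is block-diagonal with 2x2 blocks [[0, w_k], [-w_k, 0]], where w_k = 2*pi*(k+1), each pair (x[2k], x[2k+1]) rotates at constant speed and the ODE has a closed-form solution. The sketch below computes it directly, which avoids the numerical integration error of solve_ivp (whose default method is RK45 with rtol=1e-3). closed_form_solution is a hypothetical helper, not part of SciPy or scikit-fda, and the seeded default_rng generator is an assumption added for reproducibility.

```python
import numpy as np


def closed_form_solution(t, x0):
    """Exact solution of dx/dt = A x for the block-rotation matrix A above.

    For block k, with w = 2*pi*(k+1), a = x0[2k], b = x0[2k+1]:
        x[2k](t)   =  a*cos(w t) + b*sin(w t)
        x[2k+1](t) = -a*sin(w t) + b*cos(w t)
    If n is odd, the last row of A is zero, so that coordinate stays constant.
    """
    n = len(x0)
    # Start every row at its initial value (covers the constant odd coordinate).
    x = np.tile(np.asarray(x0, dtype=float)[:, None], (1, len(t)))
    for k in range(n // 2):
        w = 2 * np.pi * (k + 1)
        a, b = x0[2 * k], x0[2 * k + 1]
        x[2 * k] = a * np.cos(w * t) + b * np.sin(w * t)
        x[2 * k + 1] = -a * np.sin(w * t) + b * np.cos(w * t)
    return x


rng = np.random.default_rng(0)
t = np.linspace(0, 1, 101)
x0 = rng.normal(0, 0.5, 10)
x_exact = closed_form_solution(t, x0)  # shape (10, 101)
```

Because w_k * 1 = 2*pi*(k+1), pair k completes k+1 full periods on [0, 1], so x(1) equals x(0); each pair also keeps a constant squared norm, which makes a convenient check against the solve_ivp output.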

The scikit-fda part starts here:

from skfda import FDataGrid
from skfda.representation.basis import BSplineBasis
from skfda.preprocessing.smoothing import BasisSmoother
from skfda.misc.regularization import L2Regularization
from skfda.misc.operators import LinearDifferentialOperator
from skfda.preprocessing.smoothing.validation import SmoothingParameterSearch
from skfda.preprocessing.smoothing.validation import LinearSmootherGeneralizedCVScorer

basis_degree = 3  # cubic B-splines
basis_knots = np.linspace(t0, t1, m // 2 + 1)
basis = BSplineBasis(domain_range=(t0, t1), order=basis_degree + 1, knots=basis_knots)

yfd = FDataGrid(data_matrix=y, grid_points=t)  # n noisy curves observed at m + 1 points

# Grid search over the smoothing parameter, scored by generalized cross-validation.
# The penalty is the L2 norm of the second derivative; return_basis=True triggers the error.
smoother = SmoothingParameterSearch(
    BasisSmoother(
        basis,
        regularization=L2Regularization(LinearDifferentialOperator(2)),
        return_basis=True,
    ),
    param_values=np.exp(np.arange(-10, 0, 0.5)),
    param_name="smoothing_parameter",
    scoring=LinearSmootherGeneralizedCVScorer(),
)

smoother.fit(yfd)
print(f"The best smoothing parameter is {smoother.best_params_}")
yfd_smooth = smoother.transform(yfd)

Error messages

.../lib/python3.11/site-packages/sklearn/model_selection/_validation.py:821: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details:
Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 810, in _score
    scores = scorer(estimator, X_test, y_test)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/skfda/preprocessing/smoothing/validation.py", line 129, in __call__
    y_est, hat_matrix = _get_input_estimation_and_matrix(estimator, X)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/skfda/preprocessing/smoothing/validation.py", line 26, in _get_input_estimation_and_matrix
    hat_matrix = estimator.hat_matrix_
                 ^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'BasisSmoother' object has no attribute 'hat_matrix_'

  warnings.warn(
The best smoothing parameter is {'smoothing_parameter': 4.5399929762484854e-05}
.../lib/python3.11/site-packages/sklearn/model_selection/_search.py:979: UserWarning: One or more of the test scores are non-finite: [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan]
  warnings.warn(
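The traceback shows that LinearSmootherGeneralizedCVScorer reads estimator.hat_matrix_ directly. Generalized cross-validation needs that matrix because its denominator depends on tr(H), the effective degrees of freedom of the smoother. Because every candidate scored nan, the reported best value, 4.5399929762484854e-05 = exp(-10), is simply the first entry of param_values rather than a real optimum. The sketch below computes the textbook GCV score, GCV(lambda) = (1/n) * ||y - H y||^2 / (1 - tr(H)/n)^2, for a penalized basis smoother. It is a stand-in under stated assumptions: piecewise-linear tent functions replace BSplineBasis, a second-difference penalty on the coefficients replaces L2Regularization(LinearDifferentialOperator(2)), tent_basis, penalized_hat_matrix and gcv_score are hypothetical helpers, and scikit-fda's scorer may normalize differently.

```python
import numpy as np


def tent_basis(t, n_basis):
    """Piecewise-linear "tent" basis functions with equally spaced centers.

    Returns a (len(t), n_basis) matrix; a simple stand-in for a B-spline basis.
    """
    centers = np.linspace(t.min(), t.max(), n_basis)
    width = centers[1] - centers[0]
    return np.clip(1 - np.abs(t[:, None] - centers[None, :]) / width, 0, None)


def penalized_hat_matrix(phi, lam):
    """H = Phi (Phi^T Phi + lam * D^T D)^{-1} Phi^T.

    D takes second differences of the basis coefficients, a discrete
    stand-in for penalizing the second derivative of the fitted curve.
    """
    d = np.diff(np.eye(phi.shape[1]), n=2, axis=0)
    return phi @ np.linalg.solve(phi.T @ phi + lam * (d.T @ d), phi.T)


def gcv_score(y, h):
    """Textbook GCV: mean squared residual / (1 - tr(H)/n)^2 (lower is better)."""
    n = len(y)
    residuals = y - h @ y
    return np.mean(residuals ** 2) / (1 - np.trace(h) / n) ** 2


# Usage: choose lambda over the same grid as param_values in the report.
rng = np.random.default_rng(0)
t = np.linspace(0, 1, 101)
y = np.sin(2 * np.pi * t) + rng.normal(0, 0.2, t.size)
phi = tent_basis(t, 20)
lams = np.exp(np.arange(-10, 0, 0.5))
scores = [gcv_score(y, penalized_hat_matrix(phi, lam)) for lam in lams]
best_lam = lams[int(np.argmin(scores))]
print(f"GCV-selected lambda: {best_lam:.3g}")
```

scikit-learn's search utilities treat larger scores as better, so a scorer built on this quantity would typically return its negative.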

Version information

  • OS: Ubuntu 22.04 LTS
  • Python version: 3.11.6 | packaged by conda-forge | (main, Oct 3 2023, 10:40:35) [GCC 12.3.0]
  • scikit-fda version: 0.9
  • numpy: 1.26.0
mynanshan added the bug label Dec 28, 2023
vnmabus (Member) commented Aug 17, 2024

Sorry for taking so long to answer. I have confirmed that the bug is reproducible, and I will try to fix it as soon as I can.

@all-contributors please add @mynanshan for bug reports.

vnmabus self-assigned this Aug 17, 2024
Contributor

@vnmabus

I couldn't determine any contributions to add, did you specify any contributions?
Please make sure to use valid contribution names.

I've put up a pull request to add @mynanshan! 🎉
