| 
 | 1 | +from __future__ import print_function, division, absolute_import  | 
 | 2 | + | 
 | 3 | +import numpy as np  | 
 | 4 | +from scipy import interpolate  | 
 | 5 | + | 
 | 6 | +from astropy.modeling.core import FittableModel, Model  | 
 | 7 | +from astropy.modeling.functional_models import Shift  | 
 | 8 | +from astropy.modeling.parameters import Parameter  | 
 | 9 | +from astropy.modeling.utils import poly_map_domain, comb  | 
 | 10 | +from astropy.modeling.fitting import _FitterMeta, fitter_unit_support  | 
 | 11 | +from astropy.utils import indent, check_broadcast  | 
 | 12 | +from astropy.units import Quantity  | 
 | 13 | + | 
 | 14 | + | 
 | 15 | +__all__ = []  | 
 | 16 | + | 
 | 17 | +class SplineModel(FittableModel):  | 
 | 18 | +    """  | 
 | 19 | +    Wrapper around scipy.interpolate.splrep and splev  | 
 | 20 | +      | 
 | 21 | +    Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified,  | 
 | 22 | +    and scipy.interpolate.LSQUnivariateSpline if knots are specified  | 
 | 23 | +      | 
 | 24 | +    There are two ways to make a spline model.  | 
 | 25 | +    (1) you have the spline auto-determine knots from the data  | 
 | 26 | +    (2) you specify the knots  | 
 | 27 | +      | 
 | 28 | +    """  | 
 | 29 | +      | 
 | 30 | +    linear = False # I think? I have no idea?  | 
 | 31 | +    col_fit_deriv = False # Not sure what this is  | 
 | 32 | +      | 
 | 33 | +    def __init__(self, degree=3, smoothing=None, knots=None, extrapolate_mode=0):  | 
 | 34 | +        """  | 
 | 35 | +        Set up a spline model.  | 
 | 36 | +          | 
 | 37 | +        degree: degree of the spline (default 3)  | 
 | 38 | +            In scipy fitpack, this is "k"  | 
 | 39 | +          | 
 | 40 | +        smoothing (optional): smoothing value for automatically determining knots  | 
 | 41 | +            In scipy fitpack, this is "s"  | 
 | 42 | +            By default, uses a   | 
 | 43 | +          | 
 | 44 | +        knots (optional): spline knots (boundaries of piecewise polynomial)  | 
 | 45 | +            If not specified, will automatically determine knots based on  | 
 | 46 | +            degree + smoothing  | 
 | 47 | +              | 
 | 48 | +        extrapolate_mode (optional): how to deal with solution outside of interval.  | 
 | 49 | +            (see scipy.interpolate.splev)  | 
 | 50 | +            if 0 (default): return the extrapolated value  | 
 | 51 | +            if 1, return 0  | 
 | 52 | +            if 2, raise a ValueError  | 
 | 53 | +            if 3, return the boundary value  | 
 | 54 | +        """  | 
 | 55 | +        self._degree = degree  | 
 | 56 | +        self._smoothing = smoothing  | 
 | 57 | +        self._knots = self.verify_knots(knots)  | 
 | 58 | +        self.extrapolate_mode = extrapolate_mode  | 
 | 59 | +          | 
 | 60 | +        ## This is used to evaluate the spline  | 
 | 61 | +        ## When None, raises an error when trying to evaluate the spline  | 
 | 62 | +        self._tck = None  | 
 | 63 | +          | 
 | 64 | +        self._param_names = ()  | 
 | 65 | +          | 
 | 66 | +    def verify_knots(self, knots):  | 
 | 67 | +        """  | 
 | 68 | +        Basic knot array vetting.  | 
 | 69 | +        The goal of having this is to enable more useful error messages  | 
 | 70 | +        than scipy (if needed).  | 
 | 71 | +        """  | 
 | 72 | +        if knots is None: return None  | 
 | 73 | +        knots = np.array(knots)  | 
 | 74 | +        assert len(knots.shape) == 1, knots.shape  | 
 | 75 | +        knots = np.sort(knots)  | 
 | 76 | +        assert len(np.unique(knots)) == len(knots), knots  | 
 | 77 | +        return knots  | 
 | 78 | +      | 
 | 79 | +    ############  | 
 | 80 | +    ## Getters  | 
 | 81 | +    ############  | 
 | 82 | +    def get_degree(self):  | 
 | 83 | +        """ Spline degree (k in FITPACK) """  | 
 | 84 | +        return self._degree  | 
 | 85 | +    def get_smoothing(self):  | 
 | 86 | +        """ Spline smoothing (s in FITPACK) """  | 
 | 87 | +        return self._smoothing  | 
 | 88 | +    def get_knots(self):  | 
 | 89 | +        """ Spline knots (t in FITPACK) """  | 
 | 90 | +        return self._knots  | 
 | 91 | +    def get_coeffs(self):  | 
 | 92 | +        """ Spline coefficients (c in FITPACK) """  | 
 | 93 | +        if self._tck is not None:  | 
 | 94 | +            return self._tck[1]  | 
 | 95 | +        else:  | 
 | 96 | +            raise RuntimeError("SplineModel has not been fit yet")  | 
 | 97 | +      | 
 | 98 | +    ############  | 
 | 99 | +    ## Spline methods: not tested at all  | 
 | 100 | +    ############  | 
 | 101 | +    def derivative(self, n=1):  | 
 | 102 | +        if self._tck is None:  | 
 | 103 | +            raise RuntimeError("SplineModel has not been fit yet")  | 
 | 104 | +        else:  | 
 | 105 | +            t, c, k = self._tck  | 
 | 106 | +            return scipy.interpolate.BSpline.construct_fast(  | 
 | 107 | +                t,c,k,extrapolate=(self.extrapolate_mode==0)).derivative(n)  | 
 | 108 | +    def antiderivative(self, n=1):  | 
 | 109 | +        if self._tck is None:  | 
 | 110 | +            raise RuntimeError("SplineModel has not been fit yet")  | 
 | 111 | +        else:  | 
 | 112 | +            t, c, k = self._tck  | 
 | 113 | +            return scipy.interpolate.BSpline.construct_fast(  | 
 | 114 | +                t,c,k,extrapolate=(self.extrapolate_mode==0)).antiderivative(n)  | 
 | 115 | +    def integral(self, a, b):  | 
 | 116 | +        if self._tck is None:  | 
 | 117 | +            raise RuntimeError("SplineModel has not been fit yet")  | 
 | 118 | +        else:  | 
 | 119 | +            t, c, k = self._tck  | 
 | 120 | +            return scipy.interpolate.BSpline.construct_fast(  | 
 | 121 | +                t,c,k,extrapolate=(self.extrapolate_mode==0)).integral(a,b)  | 
 | 122 | +    def derivatives(self, x):  | 
 | 123 | +        raise NotImplementedError  | 
 | 124 | +    def roots(self):  | 
 | 125 | +        raise NotImplementedError  | 
 | 126 | +      | 
 | 127 | +    ############  | 
 | 128 | +    ## Setters: not really implemented or tested  | 
 | 129 | +    ############  | 
 | 130 | +    def reset_model(self):  | 
 | 131 | +        """ Resets model so it needs to be refit to be valid """  | 
 | 132 | +        self._tck = None  | 
 | 133 | +    def set_degree(self, degree):  | 
 | 134 | +        """ Spline degree (k in FITPACK) """  | 
 | 135 | +        raise NotImplementedError  | 
 | 136 | +        self._degree = degree  | 
 | 137 | +        self.reset_model()  | 
 | 138 | +    def set_smoothing(self, smoothing):  | 
 | 139 | +        """ Spline smoothing (s in FITPACK) """  | 
 | 140 | +        raise NotImplementedError  | 
 | 141 | +        self._smoothing = smoothing  | 
 | 142 | +        self.reset_model()  | 
 | 143 | +    def set_knots(self, knots):  | 
 | 144 | +        """ Spline knots (t in FITPACK) """  | 
 | 145 | +        raise NotImplementedError  | 
 | 146 | +        self._knots = self.verify_knots(knots)  | 
 | 147 | +        self.reset_model()  | 
 | 148 | +      | 
 | 149 | +    def set_model_from_tck(self, tck):  | 
 | 150 | +        """  | 
 | 151 | +        Use output of scipy.interpolate.splrep  | 
 | 152 | +        """  | 
 | 153 | +        self._tck = tck  | 
 | 154 | + | 
 | 155 | +    def __call__(self, x, der=0):  | 
 | 156 | +        """  | 
 | 157 | +        Evaluate the model with the given inputs.  | 
 | 158 | +        der is passed to scipy.interpolate.splev  | 
 | 159 | +        """  | 
 | 160 | +        if self._tck is None:  | 
 | 161 | +            raise RuntimeError("SplineModel has not been fit yet")  | 
 | 162 | +        return interpolate.splev(x, self._tck, der=der, ext=self.extrapolate_mode)  | 
 | 163 | +      | 
 | 164 | +    ####################################  | 
 | 165 | +    ######### Stuff below here is stubs  | 
 | 166 | +    @property  | 
 | 167 | +    def param_names(self):  | 
 | 168 | +        """  | 
 | 169 | +        Coefficient names generated based on the model's knots and polynomial degree.  | 
 | 170 | +        Not Implemented  | 
 | 171 | +        """  | 
 | 172 | +        raise NotImplementedError("SplineModel does not currently expose parameters")  | 
 | 173 | +        return self._param_names  | 
 | 174 | + | 
 | 175 | +    #def __getattr__(self, attr):  | 
 | 176 | +    #    """  | 
 | 177 | +    #    Fails right now. Future code:  | 
 | 178 | +    #    # From astropy.modeling.polynomial.PolynomialBase  | 
 | 179 | +    #    if self._param_names and attr in self._param_names:  | 
 | 180 | +    #        return Parameter(attr, default=0.0, model=self)  | 
 | 181 | +    #    raise AttributeError(attr)  | 
 | 182 | +    #    """  | 
 | 183 | +    #    raise NotImplementedError("SplineModel does not currently expose parameters")  | 
 | 184 | + | 
 | 185 | +    #def __setattr__(self, attr, value):  | 
 | 186 | +    #    """  | 
 | 187 | +    #    Fails right now. Future code:  | 
 | 188 | +    #    # From astropy.modeling.polynomial.PolynomialBase  | 
 | 189 | +    #    if attr[0] != '_' and self._param_names and attr in self._param_names:  | 
 | 190 | +    #        param = Parameter(attr, default=0.0, model=self)  | 
 | 191 | +    #        param.__set__(self, value)  | 
 | 192 | +    #    else:  | 
 | 193 | +    #        super().__setattr__(attr, value)  | 
 | 194 | +    #    """  | 
 | 195 | +    #    raise NotImplementedError("SplineModel does not currently expose parameters")  | 
 | 196 | + | 
 | 197 | +    def _generate_coeff_names(self):  | 
 | 198 | +        names = []  | 
 | 199 | +        degree, Nknots = self._degree, len(self._knots)  | 
 | 200 | +        for i in range(Nknots):  | 
 | 201 | +            for j in range(degree+1):  | 
 | 202 | +                names.append("k{}_c{}".format(i,j))  | 
 | 203 | +        return tuple(names)  | 
 | 204 | +      | 
 | 205 | +    def evaluate(self, *args, **kwargs):  | 
 | 206 | +        return self(*args, **kwargs)  | 
 | 207 | + | 
 | 208 | +          | 
 | 209 | + | 
 | 210 | +class SplineFitter(metaclass=_FitterMeta):  | 
 | 211 | +    """  | 
 | 212 | +    Run a spline fit.  | 
 | 213 | +    """  | 
 | 214 | +    def __init__(self):  | 
 | 215 | +        self.fit_info = {"fp": None,  | 
 | 216 | +                         "ier": None,  | 
 | 217 | +                         "msg": None}  | 
 | 218 | +        super().__init__()  | 
 | 219 | +      | 
 | 220 | +    def validate_model(self, model):  | 
 | 221 | +        if not isinstance(model, SplineModel):  | 
 | 222 | +            raise ValueError("model must be of type SplineModel (currently is {})".format(  | 
 | 223 | +                    type(model)))  | 
 | 224 | +      | 
 | 225 | +    ## TODO do something about units  | 
 | 226 | +    #@fitter_unit_support  | 
 | 227 | +    def __call__(self, model, x, y, w=None):  | 
 | 228 | +        """  | 
 | 229 | +        Fit a spline model to data.  | 
 | 230 | +        Internally uses scipy.interpolate.splrep.  | 
 | 231 | +          | 
 | 232 | +        """  | 
 | 233 | +          | 
 | 234 | +        self.validate_model(model)  | 
 | 235 | +          | 
 | 236 | +        ## Case (1): fit smoothing spline  | 
 | 237 | +        if model.get_knots() is None:  | 
 | 238 | +            tck, fp, ier, msg = interpolate.splrep(x, y, w=w,  | 
 | 239 | +                                                   t=None,  | 
 | 240 | +                                                   k=model.get_degree(),   | 
 | 241 | +                                                   s=model.get_smoothing(),  | 
 | 242 | +                                                   task=0, full_output=True  | 
 | 243 | +                                                   )  | 
 | 244 | +        ## Case (2): leastsq spline  | 
 | 245 | +        else:  | 
 | 246 | +            knots = model.get_knots()  | 
 | 247 | +            ## TODO some sort of validation that the knots are internal, since  | 
 | 248 | +            ## this procedure automatically adds knots at the two endpoints  | 
 | 249 | +            tck, fp, ier, msg = interpolate.splrep(x, y, w=w,  | 
 | 250 | +                                                   t=knots,  | 
 | 251 | +                                                   k=model.get_degree(),   | 
 | 252 | +                                                   s=model.get_smoothing(),  | 
 | 253 | +                                                   task=-1, full_output=True  | 
 | 254 | +                                                   )  | 
 | 255 | +          | 
 | 256 | +        model.set_model_from_tck(tck)  | 
 | 257 | +        self.fit_info.update({"fp":fp, "ier":ier, "msg":msg})  | 
 | 258 | +      | 
0 commit comments