@@ -40,62 +40,25 @@ def __new__(
40
40
The number of points in which the piecewise function will be
41
41
evaluated to create a base function. The default value is 100.
42
42
"""
43
+ cls .__validate__source (source )
43
44
if inputs is None :
44
45
inputs = ["Scalar" ]
45
46
if outputs is None :
46
47
outputs = ["Scalar" ]
47
- # Check if source is a dictionary
48
- if not isinstance (source , dict ):
49
- raise TypeError ("source must be a dictionary" )
50
- # Check if all keys are tuples
51
- for key in source .keys ():
52
- if not isinstance (key , tuple ):
53
- raise TypeError ("keys of source must be tuples" )
54
- # Check if all domains are disjoint
55
- for key1 in source .keys ():
56
- for key2 in source .keys ():
57
- if key1 != key2 :
58
- if key1 [0 ] < key2 [1 ] and key1 [1 ] > key2 [0 ]:
59
- raise ValueError ("domains must be disjoint" )
60
-
61
- # Crate Function
62
- def calc_output (func , inputs ):
63
- """Receives a list of inputs value and a function, populates another
64
- list with the results corresponding to the same results.
65
-
66
- Parameters
67
- ----------
68
- func : Function
69
- The Function object to be
70
- inputs : list, tuple, np.array
71
- The array of points to applied the func to.
72
-
73
- Examples
74
- --------
75
- >>> inputs = [0, 1, 2, 3, 4, 5]
76
- >>> def func(x):
77
- ... return x*10
78
- >>> calc_output(func, inputs)
79
- [0, 10, 20, 30, 40, 50]
80
-
81
- Notes
82
- -----
83
- In the future, consider using the built-in map function from python.
84
- """
85
- output = np .zeros (len (inputs ))
86
- for j , value in enumerate (inputs ):
87
- output [j ] = func .get_value_opt (value )
88
- return output
89
48
90
49
input_data = []
91
50
output_data = []
92
- for key in sorted (source .keys ()):
93
- i = np .linspace (key [0 ], key [1 ], datapoints )
94
- i = i [~ np .isin (i , input_data )]
95
- input_data = np .concatenate ((input_data , i ))
51
+ for interval in sorted (source .keys ()):
52
+ grid = np .linspace (interval [0 ], interval [1 ], datapoints )
53
+
54
+ # since intervals are disjoint and sorted, we only need to check
55
+ # if the first point is already included
56
+ if interval [0 ] in input_data :
57
+ grid = np .delete (grid , 0 )
58
+ input_data = np .concatenate ((input_data , grid ))
96
59
97
- f = Function (source [key ])
98
- output_data = np .concatenate ((output_data , calc_output ( f , i )))
60
+ f = Function (source [interval ])
61
+ output_data = np .concatenate ((output_data , f ( grid )))
99
62
100
63
return Function (
101
64
np .concatenate (([input_data ], [output_data ])).T ,
@@ -104,3 +67,27 @@ def calc_output(func, inputs):
104
67
interpolation = interpolation ,
105
68
extrapolation = extrapolation ,
106
69
)
70
+
71
+ @staticmethod
72
+ def __validate__source (source ):
73
+ """Validates that source is dictionary with non-overlapping
74
+ intervals
75
+
76
+ Parameters
77
+ ----------
78
+ source : dict
79
+ A dictionary of Function objects, where the keys are the domains.
80
+ """
81
+ # Check if source is a dictionary
82
+ if not isinstance (source , dict ):
83
+ raise TypeError ("source must be a dictionary" )
84
+ # Check if all keys are tuples
85
+ for key in source .keys ():
86
+ if not isinstance (key , tuple ):
87
+ raise TypeError ("keys of source must be tuples" )
88
+ # Check if all domains are disjoint
89
+ for interval1 in source .keys ():
90
+ for interval2 in source .keys ():
91
+ if interval1 != interval2 :
92
+ if interval1 [0 ] < interval2 [1 ] and interval1 [1 ] > interval2 [0 ]:
93
+ raise ValueError ("domains must be disjoint" )
0 commit comments