@@ -46,19 +46,20 @@ def __new__(
46
46
if outputs is None :
47
47
outputs = ["Scalar" ]
48
48
49
- input_data = []
50
- output_data = []
51
- for interval in sorted (source .keys ()):
52
- grid = np .linspace (interval [ 0 ], interval [ 1 ] , datapoints )
49
+ input_data = np . array ([])
50
+ output_data = np . array ([])
51
+ for lower , upper in sorted (source .keys ()):
52
+ grid = np .linspace (lower , upper , datapoints )
53
53
54
54
# since intervals are disjoint and sorted, we only need to check
55
55
# if the first point is already included
56
- if interval [0 ] in input_data :
57
- grid = np .delete (grid , 0 )
56
+ if input_data .size != 0 :
57
+ if lower == input_data [- 1 ]:
58
+ grid = np .delete (grid , 0 )
58
59
input_data = np .concatenate ((input_data , grid ))
59
60
60
- f = Function (source [interval ])
61
- output_data = np .concatenate ((output_data , f (grid )))
61
+ f = Function (source [( lower , upper ) ])
62
+ output_data = np .concatenate ((output_data , f . get_value (grid )))
62
63
63
64
return Function (
64
65
np .concatenate (([input_data ], [output_data ])).T ,
@@ -86,8 +87,8 @@ def __validate__source(source):
86
87
if not isinstance (key , tuple ):
87
88
raise TypeError ("keys of source must be tuples" )
88
89
# Check if all domains are disjoint
89
- for interval1 in source .keys ():
90
- for interval2 in source .keys ():
91
- if interval1 != interval2 :
92
- if interval1 [ 0 ] < interval2 [ 1 ] and interval1 [ 1 ] > interval2 [ 0 ] :
90
+ for lower1 , upper1 in source .keys ():
91
+ for lower2 , upper2 in source .keys ():
92
+ if ( lower1 , upper1 ) != ( lower2 , upper2 ) :
93
+ if lower1 < upper2 and upper1 > lower2 :
93
94
raise ValueError ("domains must be disjoint" )
0 commit comments