Skip to content

Commit 664825f

Browse files
committed
ENH: simplifying and optimizing the function, implementing tests.
1 parent 5033694 commit 664825f

File tree

2 files changed

+70
-48
lines changed

2 files changed

+70
-48
lines changed

rocketpy/mathutils/piecewise_function.py

+35-48
Original file line numberDiff line numberDiff line change
@@ -40,62 +40,25 @@ def __new__(
4040
The number of points in which the piecewise function will be
4141
evaluated to create a base function. The default value is 100.
4242
"""
43+
cls.__validate__source(source)
4344
if inputs is None:
4445
inputs = ["Scalar"]
4546
if outputs is None:
4647
outputs = ["Scalar"]
47-
# Check if source is a dictionary
48-
if not isinstance(source, dict):
49-
raise TypeError("source must be a dictionary")
50-
# Check if all keys are tuples
51-
for key in source.keys():
52-
if not isinstance(key, tuple):
53-
raise TypeError("keys of source must be tuples")
54-
# Check if all domains are disjoint
55-
for key1 in source.keys():
56-
for key2 in source.keys():
57-
if key1 != key2:
58-
if key1[0] < key2[1] and key1[1] > key2[0]:
59-
raise ValueError("domains must be disjoint")
60-
61-
# Crate Function
62-
def calc_output(func, inputs):
63-
"""Receives a list of inputs value and a function, populates another
64-
list with the results corresponding to the same results.
65-
66-
Parameters
67-
----------
68-
func : Function
69-
The Function object to be
70-
inputs : list, tuple, np.array
71-
The array of points to applied the func to.
72-
73-
Examples
74-
--------
75-
>>> inputs = [0, 1, 2, 3, 4, 5]
76-
>>> def func(x):
77-
... return x*10
78-
>>> calc_output(func, inputs)
79-
[0, 10, 20, 30, 40, 50]
80-
81-
Notes
82-
-----
83-
In the future, consider using the built-in map function from python.
84-
"""
85-
output = np.zeros(len(inputs))
86-
for j, value in enumerate(inputs):
87-
output[j] = func.get_value_opt(value)
88-
return output
8948

9049
input_data = []
9150
output_data = []
92-
for key in sorted(source.keys()):
93-
i = np.linspace(key[0], key[1], datapoints)
94-
i = i[~np.isin(i, input_data)]
95-
input_data = np.concatenate((input_data, i))
51+
for interval in sorted(source.keys()):
52+
grid = np.linspace(interval[0], interval[1], datapoints)
53+
54+
# since intervals are disjoint and sorted, we only need to check
55+
# if the first point is already included
56+
if interval[0] in input_data:
57+
grid = np.delete(grid, 0)
58+
input_data = np.concatenate((input_data, grid))
9659

97-
f = Function(source[key])
98-
output_data = np.concatenate((output_data, calc_output(f, i)))
60+
f = Function(source[interval])
61+
output_data = np.concatenate((output_data, f(grid)))
9962

10063
return Function(
10164
np.concatenate(([input_data], [output_data])).T,
@@ -104,3 +67,27 @@ def calc_output(func, inputs):
10467
interpolation=interpolation,
10568
extrapolation=extrapolation,
10669
)
70+
71+
@staticmethod
72+
def __validate__source(source):
73+
"""Validates that source is dictionary with non-overlapping
74+
intervals
75+
76+
Parameters
77+
----------
78+
source : dict
79+
A dictionary of Function objects, where the keys are the domains.
80+
"""
81+
# Check if source is a dictionary
82+
if not isinstance(source, dict):
83+
raise TypeError("source must be a dictionary")
84+
# Check if all keys are tuples
85+
for key in source.keys():
86+
if not isinstance(key, tuple):
87+
raise TypeError("keys of source must be tuples")
88+
# Check if all domains are disjoint
89+
for interval1 in source.keys():
90+
for interval2 in source.keys():
91+
if interval1 != interval2:
92+
if interval1[0] < interval2[1] and interval1[1] > interval2[0]:
93+
raise ValueError("domains must be disjoint")

tests/unit/test_piecewise_function.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import pytest
2+
3+
from rocketpy import PiecewiseFunction
4+
5+
6+
@pytest.mark.parametrize(
7+
"source",
8+
[
9+
((0, 4), lambda x: x),
10+
{"0-4": lambda x: x},
11+
{(0, 4): lambda x: x, (3, 5): lambda x: 2 * x},
12+
],
13+
)
14+
def test_invalid_source(source):
15+
"""Test an error is raised when the source parameter is invalid"""
16+
with pytest.raises((TypeError, ValueError)):
17+
PiecewiseFunction(source)
18+
19+
20+
@pytest.mark.parametrize(
21+
"source",
22+
[
23+
{(-1, 0): lambda x: -x, (0, 1): lambda x: x},
24+
{
25+
(0, 1): lambda x: x,
26+
(1, 2): lambda x: 1,
27+
(2, 3): lambda x: 3 - x,
28+
},
29+
],
30+
)
31+
@pytest.mark.parametrize("inputs", [None, "X"])
32+
@pytest.mark.parametrize("outputs", [None, "Y"])
33+
def test_new(source, inputs, outputs):
34+
"""Test if PiecewiseFunction.__new__ runs correctly"""
35+
PiecewiseFunction(source, inputs, outputs)

0 commit comments

Comments
 (0)