Skip to content

Commit a997774

Browse files
committed
MNT: update changelog and apply changes suggested in review
1 parent 664825f commit a997774

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Attention: The newest changes should be on top -->
4141

4242
### Changed
4343

44-
-
44+
- MNT: move piecewise functions to separate file [#746](https://github.com/RocketPy-Team/RocketPy/pull/746)
4545

4646
### Fixed
4747

rocketpy/mathutils/piecewise_function.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,20 @@ def __new__(
4646
if outputs is None:
4747
outputs = ["Scalar"]
4848

49-
input_data = []
50-
output_data = []
51-
for interval in sorted(source.keys()):
52-
grid = np.linspace(interval[0], interval[1], datapoints)
49+
input_data = np.array([])
50+
output_data = np.array([])
51+
for lower, upper in sorted(source.keys()):
52+
grid = np.linspace(lower, upper, datapoints)
5353

5454
# since intervals are disjoint and sorted, we only need to check
5555
# if the first point is already included
56-
if interval[0] in input_data:
57-
grid = np.delete(grid, 0)
56+
if input_data.size != 0:
57+
if lower == input_data[-1]:
58+
grid = np.delete(grid, 0)
5859
input_data = np.concatenate((input_data, grid))
5960

60-
f = Function(source[interval])
61-
output_data = np.concatenate((output_data, f(grid)))
61+
f = Function(source[(lower, upper)])
62+
output_data = np.concatenate((output_data, f.get_value(grid)))
6263

6364
return Function(
6465
np.concatenate(([input_data], [output_data])).T,
@@ -86,8 +87,8 @@ def __validate__source(source):
8687
if not isinstance(key, tuple):
8788
raise TypeError("keys of source must be tuples")
8889
# Check if all domains are disjoint
89-
for interval1 in source.keys():
90-
for interval2 in source.keys():
91-
if interval1 != interval2:
92-
if interval1[0] < interval2[1] and interval1[1] > interval2[0]:
90+
for lower1, upper1 in source.keys():
91+
for lower2, upper2 in source.keys():
92+
if (lower1, upper1) != (lower2, upper2):
93+
if lower1 < upper2 and upper1 > lower2:
9394
raise ValueError("domains must be disjoint")

0 commit comments

Comments
 (0)