Skip to content

Commit 4e3cd4c

Browse files
committed
Improve the accuracy of sine_taylor with upcast FMA
1 parent f18a162 commit 4e3cd4c

File tree

3 files changed

+79
-46
lines changed

3 files changed

+79
-46
lines changed

functional_algorithms/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def toidentifier(value):
174174
elif isinstance(value, numpy.floating):
175175
try:
176176
intvalue = int(value)
177-
except OverflowError:
177+
except (OverflowError, ValueError):
178178
intvalue = None
179179
if value == intvalue:
180180
return value.dtype.kind + toidentifier(intvalue)

functional_algorithms/floating_point_algorithms.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def split_veltkamp(ctx, x, C=None, scale=False):
222222
223223
"""
224224
if C is None:
225+
one = ctx.constant(1, x)
225226
C, N, invN = get_veltkamp_splitter_constants(ctx, get_largest(ctx, x))
226227
elif scale:
227228
one = ctx.constant(1, x)
@@ -382,14 +383,31 @@ def div_series(ctx, x, y):
382383
return x / y
383384

384385

386+
def _terms_add(ctx, terms, index, *operands):
387+
for i, v in enumerate(operands):
388+
if index + i < len(terms):
389+
if 1:
390+
terms[index + i] += v
391+
else:
392+
h, l = add_2sum(ctx, terms[index + i], v)
393+
terms[index + i] = h
394+
if index + i <= len(terms):
395+
terms.append(l)
396+
else:
397+
terms[index + i + 1] += l
398+
else:
399+
assert index + i == len(terms)
400+
terms.append(v)
401+
402+
385403
def _add_series_series(ctx, x, y):
386404

387405
def op(x, y):
388406
if x is None:
389407
return y
390408
if y is None:
391409
return x
392-
return x + y
410+
return add_2sum(ctx, x, y)
393411

394412
return _binaryop_series_series(ctx, x, y, op)
395413

@@ -401,7 +419,7 @@ def op(x, y):
401419
return -y
402420
if y is None:
403421
return x
404-
return x - y
422+
return add_2sum(ctx, x, -y)
405423

406424
return _binaryop_series_series(ctx, x, y, op)
407425

@@ -417,28 +435,30 @@ def _binaryop_series_series(ctx, x, y, op):
417435
assert sexp1 == sexp2, (sexp1, sexp2)
418436

419437
terms = []
420-
421438
for n in range(max(len(terms1), index1 - index2 + len(terms2))):
422439
k = n - (index1 - index2)
423440
if n < len(terms1):
424441
if k >= 0 and k < len(terms2):
425442
if swapped:
426-
terms.append(op(terms2[k], terms1[n]))
443+
r = op(terms2[k], terms1[n])
427444
else:
428-
terms.append(op(terms1[n], terms2[k]))
445+
r = op(terms1[n], terms2[k])
429446
else:
430447
if swapped:
431-
terms.append(op(None, terms1[n]))
448+
r = op(None, terms1[n])
432449
else:
433-
terms.append(op(terms1[n], None))
450+
r = op(terms1[n], None)
434451
elif k >= 0 and k < len(terms2):
435452
if swapped:
436-
terms.append(op(terms2[k], None))
453+
r = op(terms2[k], None)
437454
else:
438-
terms.append(op(None, terms2[k]))
455+
r = op(None, terms2[k])
439456
else:
440-
terms.append(ctx.constant(0, terms1[0]))
441-
457+
r = ctx.constant(0, terms1[0])
458+
if type(r) is tuple:
459+
_terms_add(ctx, terms, n, *r)
460+
else:
461+
_terms_add(ctx, terms, n, r)
442462
return ctx._series(tuple(terms), dict(unit_index=index1, scaling_exp=sexp1))
443463

444464

@@ -454,7 +474,7 @@ def add_series(ctx, x, y):
454474
return _add_series_series(ctx, x, ((0, 0), y))
455475
elif type(y) is tuple:
456476
return _add_series_series(ctx, ((0, 0), x), y)
457-
return x + y
477+
return _add_series_series(ctx, ((0, 0), x), ((0, 0), y))
458478

459479

460480
def subtract_series(ctx, x, y):
@@ -469,7 +489,7 @@ def subtract_series(ctx, x, y):
469489
return _subtract_series_series(ctx, x, ((0, 0), y))
470490
elif type(y) is tuple:
471491
return _subtract_series_series(ctx, ((0, 0), x), y)
472-
return x - y
492+
return _subtract_series_series(ctx, ((0, 0), x), ((0, 0), y))
473493

474494

475495
def mul_series_dekker(ctx, x, y, C=None):
@@ -484,22 +504,6 @@ def mul_series_dekker(ctx, x, y, C=None):
484504
x = ctx._get_series_operands(x)
485505
y = ctx._get_series_operands(y)
486506

487-
def terms_add(terms, index, *operands):
488-
for i, v in enumerate(operands):
489-
if index + i < len(terms):
490-
if 1:
491-
terms[index + i] += v
492-
else:
493-
h, l = add_2sum(ctx, terms[index + i], v)
494-
terms[index + i] = h
495-
if index + i <= len(terms):
496-
terms.append(l)
497-
else:
498-
terms[index + i + 1] += l
499-
else:
500-
assert index + i == len(terms)
501-
terms.append(v)
502-
503507
offset = 10000
504508

505509
if type(x) is tuple:
@@ -516,9 +520,9 @@ def terms_add(terms, index, *operands):
516520
for i, x_ in enumerate(x[1:]):
517521
for j, y_ in enumerate(y[1:]):
518522
if i + j >= offset:
519-
terms_add(terms, i + j, x_ * y_)
523+
_terms_add(ctx, terms, i + j, x_ * y_)
520524
else:
521-
terms_add(terms, i + j, *mul_dekker(ctx, x_, y_, C=C))
525+
_terms_add(ctx, terms, i + j, *mul_dekker(ctx, x_, y_, C=C))
522526
return ctx._series(tuple(terms), dict(unit_index=x[0][0] + y[0][0], scaling_exp=x[0][1]))
523527
else:
524528
# (x1, x2, ...) * y
@@ -529,17 +533,17 @@ def terms_add(terms, index, *operands):
529533
terms = []
530534
for i, x_ in enumerate(x[1:]):
531535
if i >= offset:
532-
terms_add(terms, i, x_ * y)
536+
_terms_add(ctx, terms, i, x_ * y)
533537
else:
534-
terms_add(terms, i, *mul_dekker(ctx, x_, y, C=C))
538+
_terms_add(ctx, terms, i, *mul_dekker(ctx, x_, y, C=C))
535539
return ctx._series(tuple(terms), dict(unit_index=x[0][0], scaling_exp=x[0][1]))
536540
elif type(y) is tuple:
537541
terms = []
538542
for i, y_ in enumerate(y[1:]):
539543
if i >= offset:
540-
terms_add(terms, i, x * y_)
544+
_terms_add(ctx, terms, i, x * y_)
541545
else:
542-
terms_add(terms, i, *mul_dekker(ctx, x, y_, C=C))
546+
_terms_add(ctx, terms, i, *mul_dekker(ctx, x, y_, C=C))
543547
return ctx._series(tuple(terms), dict(unit_index=y[0][0], scaling_exp=y[0][1]))
544548
return ctx._series(mul_dekker(ctx, x, y, C=C), dict(unit_index=0, scaling_exp=0))
545549

@@ -1356,7 +1360,20 @@ def sine_taylor(ctx, x, order=7, split=False):
13561360
C, f = [x], 1
13571361
for i in range(3, order + 1, 2):
13581362
f *= -i * (i - 1)
1359-
C.append(ctx.fma(x, ctx.constant(1 / f, x), zero))
1363+
if i >= 5:
1364+
C.append(ctx.fma(x, ctx.constant(1 / f, x), zero))
1365+
else:
1366+
# The following is required for float16 and float32 when f is
1367+
# small and early evaluation of 1/f leads to accuracy loss.
1368+
# For float64, there is a very minor accuracy loss.
1369+
#
1370+
# fh, fl = split_veltkamp(f)
1371+
# 1 / f = 1 / (fh + fl) = 1 / fh - fl / fh ** 2 + fl ** 2 / fh ** 3 - ...
1372+
# = [a = fl/fh] = 1 / fh - a / fh + a ** 2 / fh - ...
1373+
# ~ fma(-a, 1 / fh, 1 / fh)
1374+
fh, fl = split_veltkamp(ctx, ctx.constant(f, x))
1375+
a = fl / fh
1376+
C.append(ctx.fma(-a, x / fh, x / fh))
13601377
# Horner's scheme is most accurate
13611378
xx = ctx.fma(x, x, zero)
13621379
return fast_polynomial(ctx, xx, C, reverse=False, scheme=[None, horner_scheme, estrin_dac_scheme, canonical_scheme][1])

functional_algorithms/tests/test_floating_point_algorithms.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,9 @@ def test_sine_pade(dtype):
802802
show_ulp(ulp)
803803

804804

805-
@pytest.mark.parametrize("func,fma", [("sin", "upcast"), ("sin", "mul_add"), ("sin", "native"), ("numpy.sin", None)])
805+
@pytest.mark.parametrize(
806+
"func,fma", [("sin", "upcast"), ("sin", "mul_add"), ("sin_dekker", "native"), ("sin", "native"), ("numpy.sin", None)]
807+
)
806808
def test_sine_taylor(dtype, func, fma):
807809
import mpmath
808810
from collections import defaultdict
@@ -817,14 +819,26 @@ def test_sine_taylor(dtype, func, fma):
817819
mpctx = mpmath.mp
818820
for order in [optimal_order, 1, 3, 5, 7, 9, 11, 13, 17, 19][:1]:
819821

820-
@fa.targets.numpy.jit(
821-
paths=[fpa],
822-
dtype=dtype,
823-
debug=(1.5 if size <= 10 else 0),
824-
rewrite_parameters=dict(optimize_cast=False, fma_backend=fma),
825-
)
826-
def sin_func(ctx, x):
827-
return fpa.sine_taylor(ctx, x, order=order, split=False)
822+
if func == "sin_dekker":
823+
824+
@fa.targets.numpy.jit(
825+
paths=[fpa],
826+
dtype=dtype,
827+
debug=(1.5 if size <= 10 else 0),
828+
)
829+
def sin_dekker_func(ctx, x):
830+
return fpa.sine_taylor_dekker(ctx, x, order=order)
831+
832+
elif func == "sin":
833+
834+
@fa.targets.numpy.jit(
835+
paths=[fpa],
836+
dtype=dtype,
837+
debug=(1.5 if size <= 10 else 0),
838+
rewrite_parameters=dict(optimize_cast=False, fma_backend=fma),
839+
)
840+
def sin_func(ctx, x):
841+
return fpa.sine_taylor(ctx, x, order=order, split=False)
828842

829843
ulp = defaultdict(int)
830844
for x in samples:
@@ -833,6 +847,8 @@ def sin_func(ctx, x):
833847
sn = numpy.sin(x)
834848
elif func == "sin":
835849
sn = sin_func(x)
850+
elif func == "sin_dekker":
851+
sn = sin_dekker_func(x)
836852
else:
837853
assert 0, func # not implemented
838854
"""

0 commit comments

Comments
 (0)