Skip to content

Commit 72af3df

Browse files
committed
Improve the accuracy of sine_taylor with upcast FMA
1 parent f18a162 commit 72af3df

File tree

5 files changed

+86
-45
lines changed

5 files changed

+86
-45
lines changed

functional_algorithms/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def toidentifier(value):
174174
elif isinstance(value, numpy.floating):
175175
try:
176176
intvalue = int(value)
177-
except OverflowError:
177+
except (OverflowError, ValueError):
178178
intvalue = None
179179
if value == intvalue:
180180
return value.dtype.kind + toidentifier(intvalue)

functional_algorithms/floating_point_algorithms.py

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def split_veltkamp(ctx, x, C=None, scale=False):
222222
223223
"""
224224
if C is None:
225+
one = ctx.constant(1, x)
225226
C, N, invN = get_veltkamp_splitter_constants(ctx, get_largest(ctx, x))
226227
elif scale:
227228
one = ctx.constant(1, x)
@@ -382,13 +383,32 @@ def div_series(ctx, x, y):
382383
return x / y
383384

384385

386+
def _terms_add(ctx, terms, index, *operands):
387+
for i, v in enumerate(operands):
388+
if index + i < len(terms):
389+
if 1:
390+
terms[index + i] += v
391+
else:
392+
h, l = add_2sum(ctx, terms[index + i], v)
393+
terms[index + i] = h
394+
if index + i <= len(terms):
395+
terms.append(l)
396+
else:
397+
terms[index + i + 1] += l
398+
else:
399+
assert index + i == len(terms)
400+
terms.append(v)
401+
402+
385403
def _add_series_series(ctx, x, y):
386404

387405
def op(x, y):
388406
if x is None:
389407
return y
390408
if y is None:
391409
return x
410+
if ctx.parameters.get("series_uses_2sum"):
411+
return add_2sum(ctx, x, y)
392412
return x + y
393413

394414
return _binaryop_series_series(ctx, x, y, op)
@@ -401,6 +421,8 @@ def op(x, y):
401421
return -y
402422
if y is None:
403423
return x
424+
if ctx.parameters.get("series_uses_2sum"):
425+
return add_2sum(ctx, x, -y)
404426
return x - y
405427

406428
return _binaryop_series_series(ctx, x, y, op)
@@ -417,28 +439,30 @@ def _binaryop_series_series(ctx, x, y, op):
417439
assert sexp1 == sexp2, (sexp1, sexp2)
418440

419441
terms = []
420-
421442
for n in range(max(len(terms1), index1 - index2 + len(terms2))):
422443
k = n - (index1 - index2)
423444
if n < len(terms1):
424445
if k >= 0 and k < len(terms2):
425446
if swapped:
426-
terms.append(op(terms2[k], terms1[n]))
447+
r = op(terms2[k], terms1[n])
427448
else:
428-
terms.append(op(terms1[n], terms2[k]))
449+
r = op(terms1[n], terms2[k])
429450
else:
430451
if swapped:
431-
terms.append(op(None, terms1[n]))
452+
r = op(None, terms1[n])
432453
else:
433-
terms.append(op(terms1[n], None))
454+
r = op(terms1[n], None)
434455
elif k >= 0 and k < len(terms2):
435456
if swapped:
436-
terms.append(op(terms2[k], None))
457+
r = op(terms2[k], None)
437458
else:
438-
terms.append(op(None, terms2[k]))
459+
r = op(None, terms2[k])
439460
else:
440-
terms.append(ctx.constant(0, terms1[0]))
441-
461+
r = ctx.constant(0, terms1[0])
462+
if type(r) is tuple:
463+
_terms_add(ctx, terms, n, *r)
464+
else:
465+
_terms_add(ctx, terms, n, r)
442466
return ctx._series(tuple(terms), dict(unit_index=index1, scaling_exp=sexp1))
443467

444468

@@ -454,7 +478,7 @@ def add_series(ctx, x, y):
454478
return _add_series_series(ctx, x, ((0, 0), y))
455479
elif type(y) is tuple:
456480
return _add_series_series(ctx, ((0, 0), x), y)
457-
return x + y
481+
return _add_series_series(ctx, ((0, 0), x), ((0, 0), y))
458482

459483

460484
def subtract_series(ctx, x, y):
@@ -469,7 +493,7 @@ def subtract_series(ctx, x, y):
469493
return _subtract_series_series(ctx, x, ((0, 0), y))
470494
elif type(y) is tuple:
471495
return _subtract_series_series(ctx, ((0, 0), x), y)
472-
return x - y
496+
return _subtract_series_series(ctx, ((0, 0), x), ((0, 0), y))
473497

474498

475499
def mul_series_dekker(ctx, x, y, C=None):
@@ -484,22 +508,6 @@ def mul_series_dekker(ctx, x, y, C=None):
484508
x = ctx._get_series_operands(x)
485509
y = ctx._get_series_operands(y)
486510

487-
def terms_add(terms, index, *operands):
488-
for i, v in enumerate(operands):
489-
if index + i < len(terms):
490-
if 1:
491-
terms[index + i] += v
492-
else:
493-
h, l = add_2sum(ctx, terms[index + i], v)
494-
terms[index + i] = h
495-
if index + i <= len(terms):
496-
terms.append(l)
497-
else:
498-
terms[index + i + 1] += l
499-
else:
500-
assert index + i == len(terms)
501-
terms.append(v)
502-
503511
offset = 10000
504512

505513
if type(x) is tuple:
@@ -516,9 +524,9 @@ def terms_add(terms, index, *operands):
516524
for i, x_ in enumerate(x[1:]):
517525
for j, y_ in enumerate(y[1:]):
518526
if i + j >= offset:
519-
terms_add(terms, i + j, x_ * y_)
527+
_terms_add(ctx, terms, i + j, x_ * y_)
520528
else:
521-
terms_add(terms, i + j, *mul_dekker(ctx, x_, y_, C=C))
529+
_terms_add(ctx, terms, i + j, *mul_dekker(ctx, x_, y_, C=C))
522530
return ctx._series(tuple(terms), dict(unit_index=x[0][0] + y[0][0], scaling_exp=x[0][1]))
523531
else:
524532
# (x1, x2, ...) * y
@@ -529,17 +537,17 @@ def terms_add(terms, index, *operands):
529537
terms = []
530538
for i, x_ in enumerate(x[1:]):
531539
if i >= offset:
532-
terms_add(terms, i, x_ * y)
540+
_terms_add(ctx, terms, i, x_ * y)
533541
else:
534-
terms_add(terms, i, *mul_dekker(ctx, x_, y, C=C))
542+
_terms_add(ctx, terms, i, *mul_dekker(ctx, x_, y, C=C))
535543
return ctx._series(tuple(terms), dict(unit_index=x[0][0], scaling_exp=x[0][1]))
536544
elif type(y) is tuple:
537545
terms = []
538546
for i, y_ in enumerate(y[1:]):
539547
if i >= offset:
540-
terms_add(terms, i, x * y_)
548+
_terms_add(ctx, terms, i, x * y_)
541549
else:
542-
terms_add(terms, i, *mul_dekker(ctx, x, y_, C=C))
550+
_terms_add(ctx, terms, i, *mul_dekker(ctx, x, y_, C=C))
543551
return ctx._series(tuple(terms), dict(unit_index=y[0][0], scaling_exp=y[0][1]))
544552
return ctx._series(mul_dekker(ctx, x, y, C=C), dict(unit_index=0, scaling_exp=0))
545553

@@ -1356,7 +1364,20 @@ def sine_taylor(ctx, x, order=7, split=False):
13561364
C, f = [x], 1
13571365
for i in range(3, order + 1, 2):
13581366
f *= -i * (i - 1)
1359-
C.append(ctx.fma(x, ctx.constant(1 / f, x), zero))
1367+
if i >= 5:
1368+
C.append(ctx.fma(x, ctx.constant(1 / f, x), zero))
1369+
else:
1370+
# The following is required for float16 and float32 when f is
1371+
# small and early evaluation of 1/f leads to accuracy loss.
1372+
# For float64, there is a very minor accuracy loss.
1373+
#
1374+
# fh, fl = split_veltkamp(f)
1375+
# 1 / f = 1 / (fh + fl) = 1 / fh - fl / fh ** 2 + fl ** 2 / fh ** 3 - ...
1376+
# = [a = fl/fh] = 1 / fh - a / fh + a ** 2 / fh - ...
1377+
# ~ fma(-a, 1 / fh, 1 / fh)
1378+
fh, fl = split_veltkamp(ctx, ctx.constant(f, x))
1379+
a = fl / fh
1380+
C.append(ctx.fma(-a, x / fh, x / fh))
13601381
# Horner's scheme is most accurate
13611382
xx = ctx.fma(x, x, zero)
13621383
return fast_polynomial(ctx, xx, C, reverse=False, scheme=[None, horner_scheme, estrin_dac_scheme, canonical_scheme][1])

functional_algorithms/targets/numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def jit_decor(func):
297297
graph = ctx.trace(func, dtype)
298298
graph2 = graph.rewrite(
299299
fa.rewrite.ReplaceSeries(),
300-
fa.rewrite.ReplaceFma(backend=rewrite_parameters.get("fma_backend", "upcast")),
300+
fa.rewrite.ReplaceFma(backend=rewrite_parameters.get("fma_backend", "native")),
301301
this_module,
302302
fa.rewrite.RewriteWithParameters(**rewrite_parameters),
303303
)

functional_algorithms/tests/test_floating_point_algorithms.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,9 @@ def test_sine_pade(dtype):
802802
show_ulp(ulp)
803803

804804

805-
@pytest.mark.parametrize("func,fma", [("sin", "upcast"), ("sin", "mul_add"), ("sin", "native"), ("numpy.sin", None)])
805+
@pytest.mark.parametrize(
806+
"func,fma", [("sin", "upcast"), ("sin", "mul_add"), ("sin_dekker", "native"), ("sin", "native"), ("numpy.sin", None)]
807+
)
806808
def test_sine_taylor(dtype, func, fma):
807809
import mpmath
808810
from collections import defaultdict
@@ -817,14 +819,26 @@ def test_sine_taylor(dtype, func, fma):
817819
mpctx = mpmath.mp
818820
for order in [optimal_order, 1, 3, 5, 7, 9, 11, 13, 17, 19][:1]:
819821

820-
@fa.targets.numpy.jit(
821-
paths=[fpa],
822-
dtype=dtype,
823-
debug=(1.5 if size <= 10 else 0),
824-
rewrite_parameters=dict(optimize_cast=False, fma_backend=fma),
825-
)
826-
def sin_func(ctx, x):
827-
return fpa.sine_taylor(ctx, x, order=order, split=False)
822+
if func == "sin_dekker":
823+
824+
@fa.targets.numpy.jit(
825+
paths=[fpa],
826+
dtype=dtype,
827+
debug=(1.5 if size <= 10 else 0),
828+
)
829+
def sin_dekker_func(ctx, x):
830+
return fpa.sine_taylor_dekker(ctx, x, order=order)
831+
832+
elif func == "sin":
833+
834+
@fa.targets.numpy.jit(
835+
paths=[fpa],
836+
dtype=dtype,
837+
debug=(1.5 if size <= 10 else 0),
838+
rewrite_parameters=dict(optimize_cast=False, fma_backend=fma, series_uses_2sum=True),
839+
)
840+
def sin_func(ctx, x):
841+
return fpa.sine_taylor(ctx, x, order=order, split=False)
828842

829843
ulp = defaultdict(int)
830844
for x in samples:
@@ -833,6 +847,8 @@ def sin_func(ctx, x):
833847
sn = numpy.sin(x)
834848
elif func == "sin":
835849
sn = sin_func(x)
850+
elif func == "sin_dekker":
851+
sn = sin_dekker_func(x)
836852
else:
837853
assert 0, func # not implemented
838854
"""

functional_algorithms/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2213,6 +2213,10 @@ def mpf2multiword(dtype, x, p=None, max_length=None, sexp=0):
22132213
class NumpyContext:
22142214
"""A light-weight context for evaluating select with numpy inputs."""
22152215

2216+
@property
2217+
def parameters(self):
2218+
return {}
2219+
22162220
def select(self, cond, x, y):
22172221
assert isinstance(cond, (bool, numpy.bool_))
22182222
return x if cond else y

0 commit comments

Comments
 (0)