diff --git a/sage/group_prover.sage b/sage/group_prover.sage index 4f8fce966..706b859b2 100644 --- a/sage/group_prover.sage +++ b/sage/group_prover.sage @@ -177,6 +177,30 @@ class constraints: def __repr__(self): return "%s" % self +def normalize_factor(p): + """Normalizes the sign of primitive polynomials (as returned by factor()) + + This function ensures that the polynomial has a positive leading coefficient. + + This is necessary because recent sage versions (starting with v9.3 or v9.4, + we don't know) are inconsistent about the placement of the minus sign in + polynomial factorizations: + ``` + sage: R. = PolynomialRing(QQ,8,order='invlex') + sage: R((-2 * (bx - ax)) ^ 1).factor() + (-2) * (bx - ax) + sage: R((-2 * (bx - ax)) ^ 2).factor() + (4) * (-bx + ax)^2 + sage: R((-2 * (bx - ax)) ^ 3).factor() + (8) * (-bx + ax)^3 + ``` + """ + # Assert p is not 0 and that its non-zero coeffients are coprime. + # (We could just work with the primitive part p/p.content() but we want to be + # aware if factor() does not return a primitive part in future sage versions.) + assert p.content() == 1 + # Ensure that the first non-zero coefficient is positive. + return p if p.lc() > 0 else -p def conflicts(R, con): """Check whether any of the passed non-zero assumptions is implied by the zero assumptions""" @@ -204,10 +228,10 @@ def get_nonzero_set(R, assume): nonzero = set() for nz in map(numerator, assume.nonzero): for (f,n) in nz.factor(): - nonzero.add(f) + nonzero.add(normalize_factor(f)) rnz = zero.reduce(nz) for (f,n) in rnz.factor(): - nonzero.add(f) + nonzero.add(normalize_factor(f)) return nonzero @@ -222,27 +246,27 @@ def prove_nonzero(R, exprs, assume): return (False, [exprs[expr]]) allexprs = reduce(lambda a,b: numerator(a)*numerator(b), exprs, 1) for (f, n) in allexprs.factor(): - if f not in nonzero: + if normalize_factor(f) not in nonzero: ok = False if ok: return (True, None) ok = True for (f, n) in zero.reduce(allexprs).factor(): - if f not in nonzero: + if normalize_factor(f) not in nonzero: ok = False if ok: return (True, None) ok = True for expr in exprs: for (f,n) in numerator(expr).factor(): - if f not in nonzero: + if normalize_factor(f) not in nonzero: ok = False if ok: return (True, None) ok = True for expr in exprs: for (f,n) in zero.reduce(numerator(expr)).factor(): - if f not in nonzero: + if normalize_factor(f) not in nonzero: expl.add(exprs[expr]) if expl: return (False, list(expl)) @@ -279,8 +303,8 @@ def describe_extra(R, assume, assumeExtra): if base not in zero: add = [] for (f, n) in numerator(base).factor(): - if f not in nonzero: - add += ["%s" % f] + if normalize_factor(f) not in nonzero: + add += ["%s" % normalize_factor(f)] if add: ret.add((" * ".join(add)) + " = 0 [%s]" % assumeExtra.zero[base]) # Iterate over the extra nonzero expressions @@ -288,8 +312,8 @@ def describe_extra(R, assume, assumeExtra): nzr = zeroextra.reduce(numerator(nz)) if nzr not in zeroextra: for (f,n) in nzr.factor(): - if zeroextra.reduce(f) not in nonzero: - ret.add("%s != 0" % zeroextra.reduce(f)) + if normalize_factor(zeroextra.reduce(f)) not in nonzero: + ret.add("%s != 0" % normalize_factor(zeroextra.reduce(f))) return ", ".join(x for x in ret)