# Source code for diofant.polys.rootoftools

"""Implementation of RootOf class and related tools. """

from mpmath import findroot, mpc, mpf, workprec
from mpmath.libmp.libmpf import prec_to_dps

from ..core import (Add, Dummy, Expr, Float, I, Integer, Lambda, Rational,
                    cacheit, symbols, sympify)
from ..core.compatibility import ordered
from ..core.evaluate import global_evaluate
from ..core.function import AppliedUndef
from ..domains import QQ
from ..functions import root as _root
from ..functions import sign
from ..logic import false
from ..utilities import lambdify, sift
from .polyerrors import (DomainError, GeneratorsNeeded,
                         MultivariatePolynomialError, PolynomialError)
from .polyfuncs import symmetrize, viete
from .polyroots import (preprocess_roots, roots, roots_binomial, roots_cubic,
                        roots_linear, roots_quadratic, roots_quartic)
from .polytools import Poly, PurePoly, factor
from .rationaltools import together
from .rootisolation import (dup_isolate_complex_roots_sqf,
                            dup_isolate_real_roots_sqf)

# Public names exported by this module.
__all__ = 'RootOf', 'RootSum'

# Module-level caches of root isolating intervals, keyed by polynomial factor.
# Each value is a list of intervals; RootOf.interval and RootOf.refine index
# into these lists, and RootOf._reals_sorted / RootOf._complexes_sorted
# (re)populate them once the intervals have been made disjoint and ordered.
_reals_cache = {}
_complexes_cache = {}

class RootOf(Expr):
    """
    Represents k-th root of a univariate polynomial.

    The ordering used for indexing takes real roots to come before complex
    ones, sort complex roots by real part, then by imaginary part and
    finally takes complex conjugate pairs of roots to be adjacent.

    Parameters
    ==========

    f : Expr
        Univariate polynomial expression.
    x : Symbol or Integer
        Polynomial variable or the index of the root.
    index : Integer or None, optional
        Index of the root.  If None (default), parameter x is
        used instead as index.
    radicals : bool, optional
        Explicitly solve linear or quadratic polynomial
        equation (enabled by default).
    expand : bool, optional
        Expand polynomial, enabled default.
    evaluate : bool or None, optional
        Control automatic evaluation.

    Examples
    ========

    >>> expand_func(RootOf(x**3 + I*x + 2, 0))
    RootOf(x**6 + 4*x**3 + x**2 + 4, 1)

    """

    is_commutative = True

    def __new__(cls, f, x, index=None, radicals=True, expand=True, evaluate=None):
        """Construct a new RootOf object for k-th root of f.

        Raises
        ======

        ValueError
            If the root index is not an integer.
        PolynomialError
            If the polynomial is not univariate or has non-positive degree.
        IndexError
            If the index is outside of ``[-degree, degree - 1]``.
        NotImplementedError
            If the domain is not the integer ring and the leading
            coefficient is known to be zero.

        """
        x = sympify(x)

        # Two-argument form RootOf(f, k): the second argument is the index
        # and the generator is left for PurePoly to determine.
        if index is None and x.is_Integer:
            x, index = None, x
        else:
            index = sympify(index)

        if index is not None and index.is_Integer:
            index = int(index)
        else:
            raise ValueError("expected an integer root index, got %s" % index)

        poly = PurePoly(f, x, greedy=False, expand=expand)

        if not poly.is_univariate:
            raise PolynomialError("only univariate polynomials are allowed")

        degree = poly.degree()
        dom = poly.domain

        if degree <= 0:
            raise PolynomialError("can't construct RootOf object for %s" % f)

        if index < -degree or index >= degree:
            raise IndexError("root index out of [%d, %d] range, got %d" %
                             (-degree, degree - 1, index))
        elif index < 0:
            # Negative indices count from the end, like sequence indexing.
            index += degree

        if not dom.is_IntegerRing and poly.LC().is_nonzero is False:
            raise NotImplementedError("sorted roots not supported over %s" % dom)

        if evaluate is None:
            evaluate = global_evaluate[0]

        if not evaluate:
            # Keep the polynomial and index exactly as given.
            obj = Expr.__new__(cls)

            obj.poly = poly
            obj.index = index

            return obj

        if not dom.is_Exact:
            poly = poly.to_exact()

        # Closed-form roots (linear, quadratic, binomial) when available.
        trivial = cls._roots_trivial(poly, radicals)

        if trivial is not None:
            return trivial[index]

        coeff, poly = preprocess_roots(poly)

        if poly.domain.is_IntegerRing or poly.domain.is_AlgebraicField:
            root = cls._indexed_root(poly, index)
        else:
            root = poly, index

        return coeff*cls._postprocess_root(root, radicals)

    @classmethod
    def _new(cls, poly, index):
        """Construct new RootOf object from raw data."""
        obj = Expr.__new__(cls)

        obj.poly = PurePoly(poly)
        obj.index = index

        return obj

    def _hashable_content(self):
        """Return the data identifying this root: polynomial and index."""
        return self.poly, self.index

    @property
    def expr(self):
        """Return the defining polynomial as an expression."""
        return self.poly.as_expr()

    @property
    def args(self):
        """Return ``(expression, generator, index)``."""
        return self.expr, self.poly.gen, Integer(self.index)

    @property
    def free_symbols(self):
        """Return free symbols of the defining polynomial."""
        return self.poly.free_symbols

    def _eval_expand_func(self, **hints):
        """Re-express a root over an algebraic field via a retracted norm polynomial."""
        poly = self.poly
        index = self.index
        if poly.domain.is_AlgebraicField:
            minpoly, x = poly, poly.gen
            # Take norms until the coefficients leave the algebraic field.
            while minpoly.domain.is_AlgebraicField:
                _, _, minpoly = minpoly.sqf_norm()
                minpoly = minpoly.retract()
            # Walk the roots of minpoly, counting those that numerically
            # annihilate the original polynomial, until the index-th is hit.
            for idx, r in enumerate(minpoly.all_roots()):  # pragma: no branch
                if poly.as_expr().evalf(2, subs={x: r}, chop=True) == 0:
                    index -= 1
                    if index == -1:
                        break
            poly, index = minpoly, idx
        return self.func(poly.as_expr(), poly.gen, index)

    def _eval_is_real(self):
        """Return True iff the index is below the real root count.

        Returns None (unknown) if counting roots raises DomainError.
        """
        try:
            return int(self.index) < int(self.poly.count_roots())
        except DomainError:
            return None
    _eval_is_extended_real = _eval_is_real

    def _eval_is_complex(self):
        """Return True if every coefficient is complex, otherwise None."""
        if all(_.is_complex for _ in self.poly.coeffs()):
            return True

    def _eval_is_imaginary(self):
        """Decide whether the root is imaginary from its isolation rectangle."""
        if self.is_real:
            return False
        elif self.is_real is False:
            ivl = self.interval
            # True iff the real extent [ax, bx] of the rectangle contains 0.
            return ivl.ax*ivl.bx <= 0

    def _eval_is_algebraic(self):
        """Return True if every coefficient is algebraic, otherwise None."""
        if all(_.is_algebraic for _ in self.poly.coeffs()):
            return True

    def _eval_power(self, expt):
        """Simplify powers of the root using its defining polynomial."""
        p = self.poly
        if p.degree() == expt and p.length() == 2 and p.TC():
            # Two-term polynomial LC*x**n + TC: x**n equals -TC/LC.
            return -p.TC()/p.LC()
        elif ((p.domain.is_IntegerRing or p.domain.is_AlgebraicField) and
              isinstance(expt, Integer) and (expt < 0 or expt >= p.degree())):
            # Reduce x**expt modulo p (inverting modulo p for negative
            # exponents) and substitute the root into the remainder.
            b = Poly(p.gen**abs(expt), p.gen, domain=p.domain)
            if expt < 0:
                b = b.invert(p)
            x = self.doit()
            return sum(c*x**n for (n,), c in b.rem(p).terms())

    def _eval_rewrite_as_Pow(self, e, x, i):
        """Rewrite roots of cubics and quartics using explicit formulas."""
        p = self.poly
        n = p.degree()
        if n == 3:
            return roots_cubic(p)[i]
        elif n == 4:
            return roots_quartic(p)[i]

    def _eval_conjugate(self):
        """Return the complex conjugate root, when it can be determined."""
        if self.is_real:
            return self
        elif self.poly.domain.is_IntegerRing:
            nreals = self.poly.count_roots()
            # Non-real roots follow the reals as adjacent conjugate pairs:
            # an even offset past the reals maps to index + 1, odd to index - 1.
            ci = self.index + 2*((self.index - nreals + 1) % 2) - 1
            return self._new(self.poly, ci)

    @property
    def is_number(self):
        """Return True if the defining polynomial has no free symbols."""
        return not self.free_symbols

    @classmethod
    def real_roots(cls, poly, radicals=True):
        """Get real roots of a polynomial."""
        # NOTE(review): signature restored; confirm the radicals=True default.
        return cls._get_roots("_real_roots", poly, radicals)

    @classmethod
    def all_roots(cls, poly, radicals=True):
        """Get real and complex roots of a polynomial."""
        # NOTE(review): signature restored; confirm the radicals=True default.
        return cls._get_roots("_all_roots", poly, radicals)

    @classmethod
    def _get_reals_sqf(cls, factor):
        """Compute real root isolating intervals for a square-free polynomial."""
        if factor not in _reals_cache:
            reals = dup_isolate_real_roots_sqf(factor.rep.to_dense(), factor.domain, blackbox=True)
            # Only the "no real roots" result is cached here; non-empty
            # lists are stored by _reals_sorted.
            if not reals:
                _reals_cache[factor] = []
            return reals
        return _reals_cache[factor]

    @classmethod
    def _get_complexes_sqf(cls, factor):
        """Compute complex root isolating intervals for a square-free polynomial."""
        if factor not in _complexes_cache:
            complexes = dup_isolate_complex_roots_sqf(factor.rep.to_dense(), factor.domain, blackbox=True)
            # Only the empty result is cached here; non-empty lists are
            # stored by _complexes_sorted.
            if not complexes:
                _complexes_cache[factor] = []
            return complexes
        return _complexes_cache[factor]

    @classmethod
    def _get_reals(cls, factors):
        """Compute real root isolating intervals for a list of factors."""
        reals = []

        for factor, k in factors:
            reals.extend((root, factor, k)
                         for root in cls._get_reals_sqf(factor))

        return reals

    @classmethod
    def _get_complexes(cls, factors):
        """Compute complex root isolating intervals for a list of factors."""
        complexes = []

        for factor, k in factors:
            complexes.extend((root, factor, k)
                             for root in cls._get_complexes_sqf(factor))

        return complexes

    @classmethod
    def _reals_sorted(cls, reals):
        """Make real isolating intervals disjoint and sort roots."""
        factors = list({f for _, f, _ in reals})
        # A single, already cached factor is returned unchanged.
        if len(factors) == 1 and factors[0] in _reals_cache:
            return reals

        # Refine every overlapping pair of intervals until they separate.
        for i, (u, f, k) in enumerate(reals):
            for j, (v, g, m) in enumerate(reals[i + 1:]):
                while not u.is_disjoint(v):
                    u, v = u.refine(), v.refine()
                reals[i + j + 1] = (v, g, m)

            reals[i] = (u, f, k)

        reals = sorted(reals, key=lambda r: r[0].a)

        # Store the refined, ordered intervals per factor.
        cache = {}
        for root, factor, _ in reals:
            cache.setdefault(factor, []).append(root)
        _reals_cache.update(cache)

        return reals

    @classmethod
    def _complexes_sorted(cls, complexes):
        """Make complex isolating intervals disjoint and sort roots."""
        factors = list({f for _, f, _ in complexes})
        # A single, already cached factor is returned unchanged.
        if len(factors) == 1 and factors[0] in _complexes_cache:
            return complexes

        # Refine every overlapping pair of rectangles until they separate.
        for i, (u, f, k) in enumerate(complexes):
            for j, (v, g, m) in enumerate(complexes[i + 1:]):
                while not u.is_disjoint(v, check_re_refinement=True):
                    u, v = u.refine(), v.refine()
                complexes[i + j + 1] = (v, g, m)

            complexes[i] = (u, f, k)

        # Primary key: the largest ``ax`` among rectangles that are not
        # re-disjoint from this one.  Secondary key: ``ay`` of the rectangle
        # itself if its ``conj`` flag is set, else of its conjugate.
        complexes = sorted(complexes,
                           key=lambda r: (max(_[0].ax for _ in complexes
                                              if not _[0].is_disjoint(r[0], re_disjoint=True)),
                                          (r[0] if r[0].conj else r[0].conjugate()).ay))

        # Store the refined, ordered rectangles per factor.
        cache = {}
        for root, factor, _ in complexes:
            cache.setdefault(factor, []).append(root)
        _complexes_cache.update(cache)

        return complexes

    @classmethod
    def _reals_index(cls, reals, index):
        """Map initial real root index to an index in a factor where the root belongs."""
        i = 0

        for j, (_, factor, k) in enumerate(reals):  # pragma: no branch
            if index < i + k:
                poly, index = factor, 0

                # Count earlier roots belonging to the same factor.
                for _, factor, _ in reals[:j]:
                    if factor == poly:
                        index += 1

                return poly, index
            else:
                # Skip over this root, counted with its multiplicity.
                i += k

    @classmethod
    def _complexes_index(cls, complexes, index):
        """Map initial complex root index to an index in a factor where the root belongs."""
        i = 0

        for j, (_, factor, k) in enumerate(complexes):  # pragma: no branch
            if index < i + k:
                poly, index = factor, 0

                # Count earlier roots belonging to the same factor.
                for _, factor, _ in complexes[:j]:
                    if factor == poly:
                        index += 1

                # Within a factor the non-real roots follow its real roots.
                index += poly.count_roots()

                return poly, index
            else:
                # Skip over this root, counted with its multiplicity.
                i += k

    @classmethod
    def _count_roots(cls, roots):
        """Count the number of real or complex roots including multiplicities."""
        return sum(k for _, _, k in roots)

    @classmethod
    def _refine_imaginaries(cls, complexes):
        """Refine rectangles until imaginary roots are told apart.

        For each factor, rectangles are refined until the number of those
        whose real extent still contains zero does not exceed ``nimag``,
        the root count of the factor composed with ``I*x``.
        """
        sifted = sift(complexes, lambda c: c[1])
        complexes = []
        for f in ordered(sifted):
            nimag = f.compose(PurePoly(I*f.gen, f.gen,
                                       domain=f.domain.algebraic_field(I))).count_roots()
            potential_imag = list(range(len(sifted[f])))
            while len(potential_imag) > nimag:
                for i in list(potential_imag):
                    u, _, k = sifted[f][i]
                    if u.ax*u.bx > 0:
                        # Strictly left or right of the imaginary axis.
                        potential_imag.remove(i)
                    else:
                        sifted[f][i] = u.refine(), f, k
            complexes.extend(sifted[f])
        return complexes

    @classmethod
    def _indexed_root(cls, poly, index):
        """Get a root of a composite polynomial by index."""
        _, factors = poly.factor_list()

        reals = cls._get_reals(factors)
        reals_count = cls._count_roots(reals)

        if index < reals_count:
            reals = cls._reals_sorted(reals)
            return cls._reals_index(reals, index)
        else:
            complexes = cls._get_complexes(factors)
            complexes = cls._refine_imaginaries(complexes)
            complexes = cls._complexes_sorted(complexes)
            return cls._complexes_index(complexes, index - reals_count)

    @classmethod
    def _real_roots(cls, poly):
        """Get real roots of a composite polynomial."""
        _, factors = poly.factor_list()

        reals = cls._get_reals(factors)
        reals = cls._reals_sorted(reals)
        reals_count = cls._count_roots(reals)

        return [cls._reals_index(reals, index)
                for index in range(reals_count)]

    @classmethod
    def _all_roots(cls, poly):
        """Get real and complex roots of a composite polynomial."""
        if not (poly.domain.is_IntegerRing or poly.domain.is_AlgebraicField):
            # No isolation is attempted over other domains.
            return [(poly, i) for i in range(poly.degree())]

        _, factors = poly.factor_list()

        reals = cls._get_reals(factors)
        reals = cls._reals_sorted(reals)
        reals_count = cls._count_roots(reals)

        roots = [cls._reals_index(reals, index)
                 for index in range(reals_count)]

        complexes = cls._get_complexes(factors)
        complexes = cls._refine_imaginaries(complexes)
        complexes = cls._complexes_sorted(complexes)
        complexes_count = cls._count_roots(complexes)

        roots.extend(cls._complexes_index(complexes, index)
                     for index in range(complexes_count))

        return roots

    @classmethod
    @cacheit
    def _roots_trivial(cls, poly, radicals):
        """Compute roots in linear, quadratic and binomial cases."""
        n = poly.degree()

        if n == 1:
            return roots_linear(poly)

        if not radicals:
            return

        if n == 2:
            return roots_quadratic(poly)
        elif poly.length() == 2 and poly.coeff_monomial(1):
            if not poly.free_symbols_in_domain:
                return roots_binomial(poly)
            elif all(sign(_) in (-1, 1) for _ in poly.coeffs()):
                # Factor out the positive real n-th root of |tc/lc| and
                # solve x**n + sign(lc*tc) instead.
                lc, tc = poly.LC(), poly.TC()
                x, r = poly.gen, _root(abs(tc/lc), n)
                poly = Poly(x**n + sign(lc*tc), x)
                return [r*_ for _ in cls._roots_trivial(poly, radicals)]

    @classmethod
    def _preprocess_roots(cls, poly):
        """Take heroic measures to make poly compatible with RootOf."""
        dom = poly.domain

        if not dom.is_Exact:
            poly = poly.to_exact()

        coeff, poly = preprocess_roots(poly)
        dom = poly.domain

        if not dom.is_IntegerRing and poly.LC().is_nonzero is False:
            raise NotImplementedError("sorted roots not supported over %s" % dom)

        return coeff, poly

    @classmethod
    def _postprocess_root(cls, root, radicals):
        """Return the root if it is trivial or a RootOf object."""
        # NOTE(review): method name and signature restored; confirm.
        poly, index = root
        roots = cls._roots_trivial(poly, radicals)

        if roots is not None:
            return roots[index]
        else:
            return cls._new(poly, index)

    @classmethod
    def _get_roots(cls, method, poly, radicals):
        """Return postprocessed roots of specified kind."""
        # NOTE(review): method name and signature restored; confirm.
        poly = PurePoly(poly)

        coeff, poly = cls._preprocess_roots(poly)
        roots = []

        for root in getattr(cls, method)(poly):
            roots.append(coeff*cls._postprocess_root(root, radicals))

        return roots

    @property
    def interval(self):
        """Return isolation interval for the root."""
        if self.is_real:
            return _reals_cache[self.poly][self.index]
        else:
            # Non-real roots are indexed after the real ones.
            reals_count = self.poly.count_roots()
            return _complexes_cache[self.poly][self.index - reals_count]

    def refine(self):
        """Refine isolation interval for the root."""
        if self.is_real:
            root = _reals_cache[self.poly][self.index]
            _reals_cache[self.poly][self.index] = root.refine()
        else:
            reals_count = self.poly.count_roots()
            root = _complexes_cache[self.poly][self.index - reals_count]
            _complexes_cache[self.poly][self.index - reals_count] = root.refine()

    def _eval_subs(self, old, new):
        """Substitute only in the coefficients of the defining polynomial."""
        if old in self.free_symbols:
            return self.func(self.poly.subs({old: new}), *self.args[1:])
        else:
            # don't allow subs to change anything
            return self

    def _eval_evalf(self, prec):
        """Evaluate this complex root to the given precision."""
        with workprec(prec):
            g = self.poly.gen
            if not g.is_Symbol:
                d = Dummy('x')
                func = lambdify(d, self.expr.subs({g: d}), "mpmath")
            else:
                func = lambdify(g, self.expr, "mpmath")

            try:
                interval = self.interval
            except DomainError:
                return super()._eval_evalf(prec)

            while True:
                if self.is_extended_real:
                    a = mpf(str(interval.a))
                    b = mpf(str(interval.b))
                    if a == b:
                        # Degenerate interval: the endpoint is the root.
                        root = a
                        break
                    x0 = mpf(str(interval.center))
                else:
                    ax = mpf(str(interval.ax))
                    bx = mpf(str(interval.bx))
                    ay = mpf(str(interval.ay))
                    by = mpf(str(interval.by))
                    x0 = mpc(*map(str, interval.center))
                    if ax == bx and ay == by:
                        # Degenerate rectangle: its centre is the root.
                        root = x0
                        break

                try:
                    root = findroot(func, x0)
                    # If the (real or complex) root is not in the 'interval',
                    # then keep refining the interval. This happens if findroot
                    # accidentally finds a different root outside of this
                    # interval because our initial estimate 'x0' was not close
                    # enough. It is also possible that the secant method will
                    # get trapped by a max/min in the interval; the root
                    # verification by findroot will raise a ValueError in this
                    # case and the interval will then be tightened -- and
                    # eventually the root will be found.
                    if self.is_extended_real:
                        if (a <= root <= b):
                            break
                    elif (ax <= root.real <= bx and ay <= root.imag <= by
                          and (interval.ay > 0 or interval.by < 0)):
                        break
                except (ValueError, UnboundLocalError):
                    pass
                self.refine()
                interval = self.interval

        # The real part is dropped for roots known to be imaginary.
        return ((Float._new(root.real._mpf_, prec) if not self.is_imaginary else 0) +
                I*Float._new(root.imag._mpf_, prec))

    def eval_rational(self, tol):
        """
        Returns a Rational approximation to self with the tolerance tol.

        The returned instance will be at most 'tol' from the exact root.

        The following example first obtains Rational approximation to 1e-7
        accuracy for all roots of the 4-th order Legendre polynomial, and then
        evaluates it to 5 decimal digits (so all digits will be correct
        including rounding):

        >>> p = legendre_poly(4, x, polys=True)
        >>> roots = [r.eval_rational(Rational(1, 10)**7) for r in p.real_roots()]
        >>> roots = [str(r.evalf(5)) for r in roots]
        >>> roots
        ['-0.86114', '-0.33998', '0.33998', '0.86114']

        """
        if not self.is_extended_real:
            raise NotImplementedError("eval_rational() only works for real polynomials so far")
        interval = self.interval
        # Shrink the isolating interval until it is no wider than tol.
        while interval.b - interval.a > tol:
            self.refine()
            interval = self.interval
        a = Rational(str(interval.a))
        b = Rational(str(interval.b))
        return (a + b)/2

    def _eval_Eq(self, other):
        """Decide equality with a number, returning ``false`` when it can be ruled out."""
        # RootOf represents a Root, so if other is that root, it should set
        # the expression to zero *and* it should be in the interval of the
        # RootOf instance. It must also be a number that agrees with the
        # is_real value of the RootOf instance.
        if type(self) == type(other):
            return sympify(self.__eq__(other))
        if not (other.is_number and not other.has(AppliedUndef)):
            return false
        if not other.is_finite:
            return false
        z = self.expr.subs({self.expr.free_symbols.pop(): other}).is_zero
        if z is False:    # all roots will make z True but we don't know
            return false  # whether this is the right root if z is True
        o = other.is_extended_real, other.is_imaginary
        s = self.is_extended_real, self.is_imaginary
        if o != s and None not in o and None not in s:
            return false
        i = self.interval
        re, im = other.as_real_imag()
        if self.is_extended_real:
            if im:
                return false
            else:
                return sympify(i.a < other and other < i.b)
        return sympify((i.ax < re and re < i.bx) and (i.ay < im and im < i.by))

    def _eval_derivative(self, x):
        """Differentiate the root with respect to a symbol in its coefficients.

        With ``p(r) = sum(c_n*r**n) = 0``, this returns
        ``-sum(c_n'*r**n)/sum(n*c_n*r**(n - 1))``.
        """
        coeffs = self.poly.all_coeffs()
        num = sum(c.diff(x)*self**n for n, c in enumerate(reversed(coeffs)))
        den = sum(c*n*self**(n - 1) for n, c in enumerate(reversed(coeffs)))
        return -num/den

class RootSum(Expr):
    """Represents a sum of all roots of a univariate polynomial."""

    def __new__(cls, expr, func=None, x=None, auto=True, quadratic=False):
        """Construct a new RootSum instance carrying all roots of a polynomial.

        Parameters
        ==========

        expr : Expr
            Univariate polynomial expression.
        func : Lambda or function, optional
            Univariate function applied to every root.  The identity
            function is used if None (default).
        x : Symbol, optional
            Polynomial variable.
        auto : bool, optional
            If true (default), sums of rational functions are evaluated
            via ``_rational_case`` instead of being kept unevaluated.
        quadratic : bool, optional
            If true, quadratic factors are summed over their explicit
            roots.  Disabled by default.

        Raises
        ======

        MultivariatePolynomialError
            If the polynomial is not univariate.
        ValueError
            If ``func`` is not a univariate function.

        """
        coeff, poly = cls._transform(expr, x)

        if not poly.is_univariate:
            raise MultivariatePolynomialError(
                "only univariate polynomials are allowed")

        if func is None:
            func = Lambda(poly.gen, poly.gen)
        else:
            is_func = getattr(func, 'is_Function', False)

            if is_func and 1 in func.nargs:
                if not isinstance(func, Lambda):
                    func = Lambda(poly.gen, func(poly.gen))
            else:
                raise ValueError(
                    "expected a univariate function, got %s" % func)

        var, expr = func.variables[0], func.expr

        # Compensate for the scaling factor returned by preprocess_roots.
        if coeff != 1:
            expr = expr.subs({var: coeff*var})

        deg = poly.degree()

        # A constant summand contributes once per root.
        if not expr.has(var):
            return deg*expr

        # Pull terms and factors independent of the variable out of the sum.
        if expr.is_Add:
            add_const, expr = expr.as_independent(var)
        else:
            add_const = Integer(0)

        if expr.is_Mul:
            mul_const, expr = expr.as_independent(var)
        else:
            mul_const = Integer(1)

        func = Lambda(var, expr)

        rational = cls._is_func_rational(poly, func)
        factors, terms = poly.factor_list()[1], []

        for poly, k in factors:
            if poly.is_linear:
                term = func(roots_linear(poly)[0])
            elif quadratic and poly.is_quadratic:
                term = sum(map(func, roots_quadratic(poly)))
            else:
                if not rational or not auto:
                    term = cls._new(poly, func, auto)
                else:
                    term = cls._rational_case(poly, func)

            # Each factor counts with its multiplicity.
            terms.append(k*term)

        return mul_const*Add(*terms) + deg*add_const

    @classmethod
    def _new(cls, poly, func, auto=True):
        """Construct new raw RootSum instance."""
        obj = Expr.__new__(cls)

        obj.poly = poly
        obj.fun = func
        obj.auto = auto

        return obj

    @classmethod
    def new(cls, poly, func, auto=True):
        """Construct new RootSum instance."""
        rational = cls._is_func_rational(poly, func)

        if not rational or not auto:
            return cls._new(poly, func, auto)
        else:
            return cls._rational_case(poly, func)

    @classmethod
    def _transform(cls, expr, x):
        """Transform an expression to a polynomial."""
        poly = PurePoly(expr, x, greedy=False)
        return preprocess_roots(poly)

    @classmethod
    def _is_func_rational(cls, poly, func):
        """Check if a lambda is a rational function."""
        var, expr = func.variables[0], func.expr
        return expr.is_rational_function(var)

    @classmethod
    def _rational_case(cls, poly, func):
        """Handle the rational function case."""
        # One formal symbol r0, r1, ... per root of the polynomial.
        roots = symbols('r:%d' % poly.degree())
        var, expr = func.variables[0], func.expr

        # Sum the function over the formal roots and bring it to p/q.
        f = sum(expr.subs({var: r}) for r in roots)
        p, q = together(f).as_numer_denom()

        domain = QQ.poly_ring(*roots)

        p = p.expand()
        q = q.expand()

        # If Poly finds no generators, keep the expression as a single
        # coefficient and mark the polynomial as absent.
        try:
            p = Poly(p, domain=domain, expand=False)
        except GeneratorsNeeded:
            p, p_coeff = None, (p,)
        else:
            p_monom, p_coeff = zip(*p.terms())

        try:
            q = Poly(q, domain=domain, expand=False)
        except GeneratorsNeeded:
            q, q_coeff = None, (q,)
        else:
            q_monom, q_coeff = zip(*q.terms())

        # Presumably: rewrite the coefficients via elementary symmetric
        # polynomials of the roots, then replace those by the values from
        # Viete's formulas for poly -- confirm against symmetrize/viete.
        coeffs, mapping = symmetrize(p_coeff + q_coeff, formal=True)
        formulas = viete(poly, roots)
        values = [(sym, val) for (sym, _), (_, val) in zip(mapping, formulas)]

        for i, (coeff, _) in enumerate(coeffs):
            coeffs[i] = coeff.subs(values)

        # Split back into numerator and denominator coefficients.
        n = len(p_coeff)

        p_coeff = coeffs[:n]
        q_coeff = coeffs[n:]

        if p is not None:
            p = Poly(dict(zip(p_monom, p_coeff)), *p.gens).as_expr()
        else:
            p, = p_coeff

        if q is not None:
            q = Poly(dict(zip(q_monom, q_coeff)), *q.gens).as_expr()
        else:
            q, = q_coeff

        return factor(p/q)

    def _hashable_content(self):
        """Return the data identifying this sum: polynomial and function."""
        return self.poly, self.fun

    @property
    def expr(self):
        """Return the polynomial as an expression."""
        return self.poly.as_expr()

    @property
    def args(self):
        """Return ``(expression, function, generator)``."""
        return self.expr, self.fun, self.poly.gen

    @property
    def free_symbols(self):
        """Return free symbols of the polynomial and of the function."""
        return self.poly.free_symbols | self.fun.free_symbols

    @property
    def is_commutative(self):
        """Always True."""
        return True

    def doit(self, **hints):
        """Expand into an explicit sum if all roots can be found."""
        _roots = roots(self.poly, multiple=True)

        # Fewer roots than the degree: stay unevaluated.
        if len(_roots) < self.poly.degree():
            return self
        else:
            return Add(*[self.fun(r) for r in _roots])

    def _eval_evalf(self, prec):
        """Evaluate numerically by summing over numerical roots."""
        try:
            _roots = self.poly.nroots(n=prec_to_dps(prec))
        except (DomainError, PolynomialError):
            return self
        else:
            return Add(*[self.fun(r) for r in _roots])

    def _eval_derivative(self, x):
        """Differentiate the summand with respect to x."""
        var, expr = self.fun.args
        func = Lambda(var, expr.diff(x))
        return self.new(self.poly, func, self.auto)