# Source code for diofant.printing.codeprinter

from ..core import Add, Basic, Lambda, Mul, Pow, S, Symbol, sympify
from ..core.compatibility import default_sort_key
from ..core.mul import _keep_coeff
from ..core.relational import Relational
from .precedence import precedence
from .str import StrPrinter


class AssignmentError(Exception):
    """Raised if an assignment variable for a loop is missing."""
class Assignment(Relational):
    """
    Relational node describing ``lhs := rhs`` for code generation.

    Parameters
    ==========

    lhs : Expr
        Target of the assignment.  It has to be a single assignable
        object: a Symbol, MatrixSymbol, MatrixElement or Indexed (or an
        instance of a subclass of one of these).
    rhs : Expr
        Value being assigned.  Any expression is accepted as long as its
        shape is compatible with the target: a Matrix may be assigned to
        a MatrixSymbol of the same shape, but not to a plain Symbol.

    Examples
    ========

    >>> Assignment(x, y)
    x := y
    >>> Assignment(x, 0)
    x := 0
    >>> A = MatrixSymbol('A', 1, 3)
    >>> mat = Matrix([x, y, z]).T
    >>> Assignment(A, mat)
    A := Matrix([[x, y, z]])
    >>> Assignment(A[0, 1], x)
    A[0, 1] := x

    """

    rel_op = ':='

    def __new__(cls, lhs, rhs=0, **assumptions):
        from ..matrices.expressions.matexpr import MatrixElement, MatrixSymbol
        from ..tensor import Indexed

        lhs = sympify(lhs, strict=True)
        rhs = sympify(rhs, strict=True)

        # Only these kinds of objects may appear as an assignment target.
        if not isinstance(lhs, (Symbol, MatrixSymbol, MatrixElement, Indexed)):
            raise TypeError("Cannot assign to lhs of type %s." % type(lhs))

        def _is_matrix(obj):
            # Indexed objects expose ``shape`` but do not define it until
            # later, which would confuse the validation below.  For now a
            # "matrix" is anything with a shape that is not an Indexed.
            return hasattr(obj, 'shape') and not isinstance(obj, Indexed)

        target_is_matrix = _is_matrix(lhs)
        value_is_matrix = _is_matrix(rhs)

        # Both sides must have the same structure.
        if target_is_matrix and not value_is_matrix:
            raise ValueError("Cannot assign a scalar to a matrix.")
        if value_is_matrix and not target_is_matrix:
            raise ValueError("Cannot assign a matrix to a scalar.")
        if target_is_matrix and lhs.shape != rhs.shape:
            raise ValueError("Dimensions of lhs and rhs don't align.")

        return Relational.__new__(cls, lhs, rhs, **assumptions)
class CodePrinter(StrPrinter):
    """The base class for code-printing subclasses.

    Besides the hooks that raise ``NotImplementedError`` below, the methods
    of this class read ``self.language``, ``self.known_functions``,
    ``self._traverse_matrix_indices`` and the ``human``, ``contract`` and
    ``precision`` settings.  None of these are defined in this class; they
    are presumably supplied by concrete printers.
    """

    # Textual form of the boolean operators; 'xor' and 'equivalent' are
    # looked up as well and treated as unsupported when absent.
    _operators = {
        'and': '&&',
        'or': '||',
        'not': '!',
    }

    _default_settings = {'order': None,
                         'full_prec': 'auto',
                         'error_on_reserved': False,
                         'reserved_word_suffix': '_'}

    def __init__(self, settings=None):
        super().__init__(settings=settings)

        # Names that _print_Symbol must not emit verbatim.
        self.reserved_words = set()

    def doprint(self, expr, assign_to=None):
        """
        Print the expression as code.

        Parameters
        ==========

        expr : Expression
            The expression to be printed.

        assign_to : Symbol, MatrixSymbol, or string (optional)
            If provided, the printed code will set the expression to a
            variable with name ``assign_to``.

        Returns
        =======

        str or tuple
            With the ``human`` setting enabled, the formatted code as one
            string, preceded by comments naming unsupported expression
            types and by declarations of the numeric constants used.
            Otherwise the tuple ``(number_symbols, not_supported, code)``.

        Raises
        ======

        TypeError
            If ``assign_to`` is neither a string, a ``Basic`` instance
            nor ``None``.

        """
        from ..matrices import MatrixSymbol

        if isinstance(assign_to, str):
            if expr.is_Matrix:
                assign_to = MatrixSymbol(assign_to, *expr.shape)
            else:
                assign_to = Symbol(assign_to)
        elif not isinstance(assign_to, (Basic, type(None))):
            raise TypeError("{0} cannot assign to object of type {1}".format(
                type(self).__name__, type(assign_to)))

        if assign_to:
            expr = Assignment(assign_to, expr)
        else:
            # Plain (non-strict) sympify is used on purpose here; the
            # original note says the alternative errors on iterables.
            expr = sympify(expr)

        # keep a set of expressions that are not strictly translatable to Code
        # and number constants that must be declared and initialized
        self._not_supported = set()
        self._number_symbols = set()

        try:
            lines = self._print(expr).splitlines()

            # format the output
            if self._settings["human"]:
                frontlines = []
                if self._not_supported:
                    frontlines.append(self._get_comment(
                        "Not supported in {0}:".format(self.language)))
                    for unsupported in sorted(self._not_supported, key=str):
                        frontlines.append(
                            self._get_comment(type(unsupported).__name__))
                for name, value in sorted(self._number_symbols, key=str):
                    frontlines.append(self._declare_number_const(name, value))
                lines = frontlines + lines
                lines = self._format_code(lines)
                result = "\n".join(lines)
            else:
                lines = self._format_code(lines)
                result = (self._number_symbols, self._not_supported,
                          "\n".join(lines))
        finally:
            # The bookkeeping sets only live for the duration of one call;
            # drop them even if printing failed so no stale state remains
            # on the printer.  (The tuple result keeps its own references.)
            del self._not_supported
            del self._number_symbols
        return result

    def _doprint_loops(self, expr, assign_to=None):
        """Print ``expr`` containing Indexed objects as explicit loops.

        Returns the generated statements joined by newlines.
        """
        # Here we print an expression that contains Indexed objects, they
        # correspond to arrays in the generated code.  The low-level
        # implementation involves looping over array elements and possibly
        # storing results in temporary variables or accumulate it in the
        # assign_to object.

        assert self._settings['contract']

        from ..tensor import get_contraction_structure

        # Setup loops over non-dummy indices  --  all terms need these
        indices = self._get_expression_indices(expr, assign_to)
        # Setup loops over dummy indices  --  each term needs separate treatment
        dummies = get_contraction_structure(expr)

        openloop, closeloop = self._get_loop_opening_ending(indices)

        # terms with no summations first
        if None in dummies:
            text = StrPrinter.doprint(self, Add(*dummies[None]))
        else:
            # If all terms have summations we must initialize array to Zero
            text = StrPrinter.doprint(self, 0)

        # skip redundant assignments (where lhs == rhs)
        lhs_printed = self._print(assign_to)
        lines = []
        if text != lhs_printed:
            lines.extend(openloop)
            assert assign_to is not None
            text = self._get_statement("%s = %s" % (lhs_printed, text))
            lines.append(text)
            lines.extend(closeloop)

        # then terms with summations
        for d in dummies:
            if isinstance(d, tuple):
                indices = self._sort_optimized(d, expr)
                openloop_d, closeloop_d = self._get_loop_opening_ending(
                    indices)

                for term in dummies[d]:
                    if term in dummies and not ([list(f) for f in dummies[term]]
                                                == [[None] for f in dummies[term]]):  # pragma: no cover
                        # If one factor in the term has it's own internal
                        # contractions, those must be computed first.
                        # (temporary variables?)
                        raise NotImplementedError(
                            "FIXME: no support for contractions in factor yet")
                    else:

                        # We need the lhs expression as an accumulator for
                        # the loops, i.e
                        #
                        # for (int d=0; d < dim; d++){
                        #    lhs[] = lhs[] + term[][d]
                        # }           ^.................. the accumulator
                        #
                        # We check if the expression already contains the
                        # lhs, and raise an exception if it does, as that
                        # syntax is currently undefined.  FIXME: What would be
                        # a good interpretation?
                        assert assign_to is not None
                        assert not term.has(assign_to)

                        lines.extend(openloop)
                        lines.extend(openloop_d)
                        text = "%s = %s" % (lhs_printed, StrPrinter.doprint(
                            self, assign_to + term))
                        lines.append(self._get_statement(text))
                        lines.extend(closeloop_d)
                        lines.extend(closeloop)

        return "\n".join(lines)

    def _get_expression_indices(self, expr, assign_to):
        """Return the loop-ordered free indices shared by lhs and rhs.

        Raises ``ValueError`` when the indices of ``assign_to`` differ from
        the non-dummy indices of ``expr`` (after scalar broadcasting).
        """
        from ..tensor import get_indices
        rinds, junk = get_indices(expr)
        linds, junk = get_indices(assign_to)

        # support broadcast of scalar
        if linds and not rinds:
            rinds = linds
        if rinds != linds:
            raise ValueError("lhs indices must match non-dummy"
                             " rhs indices in %s" % expr)

        return self._sort_optimized(rinds, assign_to)

    def _sort_optimized(self, indices, expr):
        """Return ``indices`` sorted by ascending score (see below)."""
        from ..tensor import Indexed

        if not indices:
            return []

        # determine optimized loop order by giving a score to each index
        # the index with the highest score are put in the innermost loop.
        score_table = {}
        for i in indices:
            score_table[i] = 0

        arrays = expr.atoms(Indexed)
        for arr in arrays:
            for p, ind in enumerate(arr.indices):
                try:
                    score_table[ind] += self._rate_index_position(p)
                except KeyError:
                    # index of the array that is not one of ``indices``
                    pass

        return sorted(indices, key=lambda x: score_table[x])

    def _rate_index_position(self, p):
        """function to calculate score based on position among indices

        This method is used to sort loops in an optimized order, see
        CodePrinter._sort_optimized()
        """
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")  # pragma: no cover

    def _get_statement(self, codestring):
        """Formats a codestring with the proper line ending."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")  # pragma: no cover

    def _get_comment(self, text):
        """Formats a text string as a comment."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")  # pragma: no cover

    def _declare_number_const(self, name, value):
        """Declare a numeric constant at the top of a function."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")  # pragma: no cover

    def _format_code(self, lines):
        """Take in a list of lines of code, and format them accordingly.

        This may include indenting, wrapping long lines, etc...
        """
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")  # pragma: no cover

    def _get_loop_opening_ending(self, indices):
        """Returns a tuple (open_lines, close_lines) containing lists
        of codelines
        """
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")  # pragma: no cover

    def _print_Assignment(self, expr):
        """Print an Assignment, expanding the multi-line special cases."""
        from ..functions import Piecewise
        from ..matrices import MatrixSymbol
        from ..tensor.indexed import IndexedBase
        lhs = expr.lhs
        rhs = expr.rhs
        # We special case assignments that take multiple lines
        if isinstance(expr.rhs, Piecewise):
            # Here we modify Piecewise so each expression is now
            # an Assignment, and then continue on the print.
            expressions = []
            conditions = []
            for (e, c) in rhs.args:
                expressions.append(Assignment(lhs, e))
                conditions.append(c)
            temp = Piecewise(*zip(expressions, conditions))
            return self._print(temp)
        elif isinstance(lhs, MatrixSymbol):
            # Here we form an Assignment for each element in the array,
            # printing each one.
            lines = []
            for (i, j) in self._traverse_matrix_indices(lhs):
                temp = Assignment(lhs[i, j], rhs[i, j])
                code0 = self._print(temp)
                lines.append(code0)
            return "\n".join(lines)
        elif self._settings["contract"] and (lhs.has(IndexedBase) or
                                             rhs.has(IndexedBase)):
            # Here we check if there is looping to be done, and if so
            # print the required loops.
            return self._doprint_loops(rhs, lhs)
        else:
            lhs_code = self._print(lhs)
            rhs_code = self._print(rhs)
            return self._get_statement("%s = %s" % (lhs_code, rhs_code))

    def _print_Symbol(self, expr):
        """Print a symbol name, disambiguating reserved words.

        A reserved name either raises ``ValueError`` (setting
        ``error_on_reserved``) or gets ``reserved_word_suffix`` appended.
        """
        name = super()._print_Symbol(expr)

        if name in self.reserved_words:
            if self._settings['error_on_reserved']:
                msg = ('This expression includes the symbol "{}" which is a '
                       'reserved keyword in this language.')
                raise ValueError(msg.format(name))
            return name + self._settings['reserved_word_suffix']
        else:
            return name

    def _print_Function(self, expr):
        """Print a function call via ``known_functions`` or inline ``_imp_``.

        An entry of ``known_functions`` is either a string or an iterable of
        ``(condition, function)`` pairs; the first pair whose condition
        accepts the arguments is used.
        """
        if expr.func.__name__ in self.known_functions:
            cond_func = self.known_functions[expr.func.__name__]
            func = None
            if isinstance(cond_func, str):
                func = cond_func
            else:
                for cond, func in cond_func:
                    if cond(*expr.args):
                        break
                else:
                    # no condition matched the arguments
                    return self._print_not_supported(expr)
            try:
                return func(*[self.parenthesize(item, 0) for item in expr.args])
            except TypeError:
                # ``func`` is not callable (e.g. a plain name): emit a call
                return "%s(%s)" % (func, self.stringify(expr.args, ", "))
        elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
            # inlined function
            return self._print(expr._imp_(*expr.args))
        else:
            return self._print_not_supported(expr)

    def _print_NumberSymbol(self, expr):
        """Register a numeric constant and print it by name."""
        # A Number symbol that is not implemented here or with _printmethod
        # is registered and evaluated
        self._number_symbols.add((expr,
                                  self._print(expr.evalf(self._settings["precision"]))))
        return str(expr)

    def _print_Dummy(self, expr):
        # dummies must be printed as unique symbols
        return "%s_%i" % (expr.name, expr.dummy_index)  # Dummy

    def _print_Catalan(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_EulerGamma(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_GoldenRatio(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Exp1(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Pi(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_And(self, expr):
        PREC = precedence(expr)
        return (" %s " % self._operators['and']).join(self.parenthesize(a, PREC)
                                                      for a in sorted(expr.args, key=default_sort_key))

    def _print_Or(self, expr):
        PREC = precedence(expr)
        return (" %s " % self._operators['or']).join(self.parenthesize(a, PREC)
                                                     for a in sorted(expr.args, key=default_sort_key))

    def _print_Xor(self, expr):
        if self._operators.get('xor') is None:
            return self._print_not_supported(expr)
        PREC = precedence(expr)
        return (" %s " % self._operators['xor']).join(self.parenthesize(a, PREC)
                                                      for a in expr.args)

    def _print_Equivalent(self, expr):
        if self._operators.get('equivalent') is None:
            return self._print_not_supported(expr)
        PREC = precedence(expr)
        return (" %s " % self._operators['equivalent']).join(self.parenthesize(a, PREC)
                                                             for a in expr.args)

    def _print_Not(self, expr):
        PREC = precedence(expr)
        return self._operators['not'] + self.parenthesize(expr.args[0], PREC)

    def _print_Mul(self, expr):
        """Print a product as ``numerator/denominator`` with a leading sign."""
        prec = precedence(expr)

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = "-"
        else:
            sign = ""

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        if self.order != 'none':
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = Mul.make_args(expr)

        # Gather args for numerator/denominator
        for item in args:
            if (item.is_commutative and item.is_Pow and
                    item.exp.is_Rational and item.exp.is_negative):
                if item.exp != -1:
                    b.append(Pow(item.base, -item.exp, evaluate=False))
                else:
                    b.append(Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        a_str = [self.parenthesize(x, prec) for x in a]
        b_str = [self.parenthesize(x, prec) for x in b]

        if len(b) == 0:
            return sign + '*'.join(a_str)
        elif len(b) == 1:
            return sign + '*'.join(a_str) + "/" + b_str[0]
        else:
            return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)

    def _print_not_supported(self, expr):
        """Record ``expr`` as unsupported and fall back to emptyPrinter."""
        self._not_supported.add(expr)
        return self.emptyPrinter(expr)

    # The following can not be simply translated into C or Fortran
    _print_Basic = _print_not_supported
    _print_ComplexInfinity = _print_not_supported
    _print_Derivative = _print_not_supported
    _print_dict = _print_not_supported
    _print_ExprCondPair = _print_not_supported
    _print_GeometryEntity = _print_not_supported
    _print_Infinity = _print_not_supported
    _print_Integral = _print_not_supported
    _print_Interval = _print_not_supported
    _print_Limit = _print_not_supported
    _print_list = _print_not_supported
    _print_Matrix = _print_not_supported
    _print_ImmutableMatrix = _print_not_supported
    _print_MutableDenseMatrix = _print_not_supported
    _print_MatrixBase = _print_not_supported
    _print_DeferredVector = _print_not_supported
    _print_NaN = _print_not_supported
    _print_NegativeInfinity = _print_not_supported
    _print_Normal = _print_not_supported
    _print_Order = _print_not_supported
    _print_RootOf = _print_not_supported
    _print_RootsOf = _print_not_supported
    _print_RootSum = _print_not_supported
    _print_SparseMatrix = _print_not_supported
    _print_tuple = _print_not_supported
    _print_Wild = _print_not_supported
    _print_WildFunction = _print_not_supported