# Source code for daetools.code_generators.formatter

"""
***********************************************************************************
                            formatter.py
                DAE Tools: pyDAE module, www.daetools.com
                Copyright (C) Dragan Nikolic
***********************************************************************************
DAE Tools is free software; you can redistribute it and/or modify it under the
terms of the GNU General Public License version 3 as published by the Free Software
Foundation. DAE Tools is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE. See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License along with the
DAE Tools software; if not, see <http://www.gnu.org/licenses/>.
************************************************************************************
"""
import sys, numpy, math, traceback
from daetools.pyDAE import *

class daeExpressionFormatter(object):
    """Format daetools runtime equation/condition node trees as source-code text.

    All generated text comes from ``str.format()`` templates stored as instance
    attributes (``domain``, ``parameter``, ``variable``, ``derivative``,
    ``constant``, the operator/function templates, ``TIME``, ...). Together with
    the indexing and naming flags set in ``__init__``, they determine the output
    syntax. A target language can therefore be supported by changing these
    attributes. formatQuantity(), formatUnits() and formatNumpyArray() are
    commonly redefined in derived classes.
    """

    def __init__(self):
        # Equation and condition node formatting settings
        # Index base in arrays:
        #  - Modelica, gPROMS use 1
        #  - daetools, python, c/c++ use 0
        #  - FMI not relevant (variables are flattened)
        self.indexBase = 0

        # When True, variables whose entry in self.IDs equals cnAssigned are
        # emitted through the 'assignedVariable' template with flattened names.
        self.useFlattenedNamesForAssignedVariables = False
        # Maps overall variable index -> variable type ID (compared against cnAssigned)
        self.IDs = {}
        # Maps overall variable index -> block index (missing entries give blockIndex = -1)
        self.indexMap = {}

        # Use relative names (relative to domains/parameters/variables model) or full canonical names
        # If we are in model root.comp1 then variables' names could be:
        #   if useRelativeNames is True:
        #       name = 'comp2.Var' (relative to parent comp1)
        #   else:
        #       name = 'root.comp1.comp2.Var' (full canonical name)
        self.useRelativeNames = True
        # When True, names are additionally passed through flattenIdentifier()
        self.flattenIdentifiers = False

        # Template keywords: domain, index, value
        self.domain = '{domain}[{index}]'

        # Template keywords: parameter, indexes, value
        self.parameter = '{parameter}({indexes})'
        self.parameterIndexStart = ''
        self.parameterIndexEnd = ''
        self.parameterIndexDelimiter = ','

        # Template keywords: variable, indexes, overallIndex, blockIndex
        self.variable = '{variable}({indexes})'
        self.variableIndexStart = ''
        self.variableIndexEnd = ''
        self.variableIndexDelimiter = ','

        # Template keywords: variable (flattened name incl. indexes), overallIndex, blockIndex
        self.assignedVariable = '{variable}'

        # String format for the time derivative, ie. der(variable[1,2]) in Modelica
        # daetools use: variable.dt(1,2), gPROMS $variable(1,2) ...
        # Template keywords: variable, indexes, overallIndex, blockIndex
        self.derivative = '{variable}.dt({indexes})'
        self.derivativeIndexStart = ''
        self.derivativeIndexEnd = ''
        self.derivativeIndexDelimiter = ','

        # Finite element items.
        # feMatrixItem keywords: matrixName, row, column, value
        # feVectorItem keywords: vectorName, row, value
        self.feMatrixItem = '{value}'
        self.feVectorItem = '{value}'

        # Constants (keywords: value, units)
        self.constant = '{value}'

        # External functions (keyword: name)
        self.scalarExternalFunction = '{name}()'
        self.vectorExternalFunction = '{name}()'

        # Logical operators
        self.AND = '{leftValue} and {rightValue}'
        self.OR = '{leftValue} or {rightValue}'
        self.NOT = 'not {value}'

        self.EQ = '{leftValue} == {rightValue}'
        self.NEQ = '{leftValue} != {rightValue}'
        self.LT = '{leftValue} < {rightValue}'
        self.LTEQ = '{leftValue} <= {rightValue}'
        self.GT = '{leftValue} > {rightValue}'
        self.GTEQ = '{leftValue} >= {rightValue}'

        # Mathematical operators
        self.SIGN = '-{value}'

        self.PLUS = '{leftValue} + {rightValue}'
        self.MINUS = '{leftValue} - {rightValue}'
        self.MULTI = '{leftValue} * {rightValue}'
        self.DIVIDE = '{leftValue} / {rightValue}'
        self.POWER = '{leftValue} ^ {rightValue}'

        # Mathematical functions
        self.SIN = 'sin({value})'
        self.COS = 'cos({value})'
        self.TAN = 'tan({value})'
        self.ASIN = 'asin({value})'
        self.ACOS = 'acos({value})'
        self.ATAN = 'atan({value})'
        self.EXP = 'exp({value})'
        self.SQRT = 'sqrt({value})'
        self.LOG = 'log({value})'
        self.LOG10 = 'log10({value})'
        self.FLOOR = 'floor({value})'
        self.CEIL = 'ceil({value})'
        self.ABS = 'abs({value})'
        self.SINH = 'sinh({value})'
        self.COSH = 'cosh({value})'
        self.TANH = 'tanh({value})'
        self.ASINH = 'asinh({value})'
        self.ACOSH = 'acosh({value})'
        self.ATANH = 'atanh({value})'
        self.ERF = 'erf({value})'

        self.MIN = 'min({leftValue}, {rightValue})'
        self.MAX = 'max({leftValue}, {rightValue})'
        self.ARCTAN2 = 'atan2({leftValue}, {rightValue})'

        # Current time in simulation
        self.TIME = 'time'

        # Internal data: model will be set by the analyzer
        self.modelCanonicalName = None

    def formatQuantity(self, quantity):
        """Format a constant/quantity that has a value and units as '{value units}'."""
        return '{{{0} {1}}}'.format(quantity.value, self.formatUnits(quantity.units))

    def formatUnits(self, units):
        """Format a units object as text, e.g. 'm kg^2/(s^2)' for m * kg**2 / s**2.

        Units with exponent >= 0 form the numerator and units with exponent < 0
        the denominator; each keeps the iteration order of ``units.toDict()``.
        Integral exponents are printed as ints and an exponent of +/-1 is omitted.
        Exponent 0 is printed as 'u^0'. The denominator is parenthesised when it
        has more than one unit. If no unit has a non-negative exponent (including
        an empty dict), the numerator is the literal 'rad'.
        NOTE(review): the reason for the 'rad' placeholder is not visible here;
        presumably it keeps a numerator present for the target syntax.
        """
        positive = []
        negative = []
        # Single pass over the units dictionary (the original walked it twice).
        for u, exp in units.toDict().items():
            if exp >= 0:
                if exp == 1:
                    positive.append('{0}'.format(u))
                elif int(exp) == exp:
                    positive.append('{0}^{1}'.format(u, int(exp)))
                else:
                    positive.append('{0}^{1}'.format(u, exp))
            elif exp < 0:
                if exp == -1:
                    negative.append('{0}'.format(u))
                elif int(exp) == exp:
                    negative.append('{0}^{1}'.format(u, int(math.fabs(exp))))
                else:
                    negative.append('{0}^{1}'.format(u, math.fabs(exp)))

        if len(positive) == 0:
            sPositive = 'rad'
        else:
            sPositive = ' '.join(positive)

        if len(negative) == 0:
            sNegative = ''
        elif len(negative) == 1:
            sNegative = '/' + ' '.join(negative)
        else:
            sNegative = '/(' + ' '.join(negative) + ')'

        return sPositive + sNegative

    def formatNumpyArray(self, arr):
        """Recursively format numpy arrays / lists as '[a, b, ...]'; other values via str()."""
        if isinstance(arr, (numpy.ndarray, list)):
            return '[' + ', '.join([self.formatNumpyArray(val) for val in arr]) + ']'
        else:
            return str(arr)

    def formatIdentifier(self, identifier):
        """Remove illegal characters from domains/parameters/variables/ports/models/... names.

        '(' ')' ',' become '_'; '&' ';' ' ' '{' '}' '\\' are deleted; '.' and '[' ']' are kept.
        """
        return identifier.replace('&', '').replace(';', '').replace('(', '_').replace(')', '_') \
                         .replace(',', '_').replace(' ', '').replace('{', '').replace('}', '') \
                         .replace('\\', '')

    def flattenIdentifier(self, identifier):
        """Flatten a name into a single identifier.

        '.' '(' ')' '[' ']' ',' become '_'; ' ' '{' '}' '\\' are deleted.
        """
        return identifier.replace('.', '_').replace('(', '_').replace(')', '_').replace('[', '_') \
                         .replace(']', '_').replace(',', '_').replace(' ', '').replace('{', '') \
                         .replace('}', '').replace('\\', '')

    def _formatName(self, canonicalName):
        """Resolve a canonical name for output.

        Uses the name relative to self.modelCanonicalName when useRelativeNames
        is True, always removes illegal characters, and flattens the name when
        flattenIdentifiers is True.
        """
        if self.useRelativeNames:
            name = daeGetRelativeName(self.modelCanonicalName, canonicalName)
        else:
            name = canonicalName
        # Always remove illegal characters
        name = self.formatIdentifier(name)
        if self.flattenIdentifiers:
            name = self.flattenIdentifier(name)
        return name

    def _formatDomainIndexes(self, domainIndexes, start, delimiter, end):
        """Join domain indexes (shifted by indexBase) as start + 'i<delim>j...' + end; '' if none."""
        if len(domainIndexes) > 0:
            return start + delimiter.join([str(di + self.indexBase) for di in domainIndexes]) + end
        return ''

    def _overallAndBlockIndexes(self, overallIndex):
        """Return (overallIndex + indexBase, mapped block index + indexBase or -1 if unmapped)."""
        overall_ = overallIndex + self.indexBase
        if overallIndex in self.indexMap:
            block_ = self.indexMap[overallIndex] + self.indexBase
        else:
            block_ = -1
        return overall_, block_

    def formatDomain(self, domainCanonicalName, index, value):
        """Format a domain point using the 'domain' template (keywords: domain, index, value)."""
        # ACHTUNG, ACHTUNG!! Take care of indexing of the domain index
        name = self._formatName(domainCanonicalName)
        indexes = str(index + self.indexBase)
        return self.domain.format(domain = name, index = indexes, value = value)

    def formatParameter(self, parameterCanonicalName, domainIndexes, value):
        """Format a parameter using the 'parameter' template (keywords: parameter, indexes, value)."""
        # ACHTUNG, ACHTUNG!! Take care of indexing of the domainIndexes
        name = self._formatName(parameterCanonicalName)
        domainindexes = self._formatDomainIndexes(domainIndexes,
                                                  self.parameterIndexStart,
                                                  self.parameterIndexDelimiter,
                                                  self.parameterIndexEnd)
        return self.parameter.format(parameter = name, indexes = domainindexes, value = value)

    def formatVariable(self, variableCanonicalName, domainIndexes, overallIndex):
        """Format a variable reference.

        If useFlattenedNamesForAssignedVariables is True and self.IDs[overallIndex]
        equals cnAssigned, the relative, flattened name with '_'-joined indexes is
        formatted with the 'assignedVariable' template. Otherwise the 'variable'
        template is used. Both receive overallIndex and blockIndex keywords.
        Raises KeyError if the flattened-assigned mode is on and overallIndex is
        missing from self.IDs.
        """
        # ACHTUNG, ACHTUNG!! Take care of indexing of the overallIndex and the domainIndexes
        overall_, block_ = self._overallAndBlockIndexes(overallIndex)

        if self.useFlattenedNamesForAssignedVariables and (self.IDs[overallIndex] == cnAssigned):
            # Assigned variables always use the relative name, regardless of useRelativeNames
            name = daeGetRelativeName(self.modelCanonicalName, variableCanonicalName)
            # Always remove illegal characters
            name = self.formatIdentifier(name)
            # Flatten the name as requested
            name = self.flattenIdentifier(name)
            domainindexes = self._formatDomainIndexes(domainIndexes, '_', '_', '')
            return self.assignedVariable.format(variable = name + domainindexes,
                                                overallIndex = overall_,
                                                blockIndex = block_)

        name = self._formatName(variableCanonicalName)
        domainindexes = self._formatDomainIndexes(domainIndexes,
                                                  self.variableIndexStart,
                                                  self.variableIndexDelimiter,
                                                  self.variableIndexEnd)
        return self.variable.format(variable = name,
                                    indexes = domainindexes,
                                    overallIndex = overall_,
                                    blockIndex = block_)

    def formatTimeDerivative(self, variableCanonicalName, domainIndexes, overallIndex):
        """Format a time derivative using the 'derivative' template.

        Keywords passed: variable, indexes, overallIndex, blockIndex.
        """
        # ACHTUNG, ACHTUNG!! Take care of indexing of the overallIndex and the domainIndexes
        name = self._formatName(variableCanonicalName)
        overall_, block_ = self._overallAndBlockIndexes(overallIndex)
        domainindexes = self._formatDomainIndexes(domainIndexes,
                                                  self.derivativeIndexStart,
                                                  self.derivativeIndexDelimiter,
                                                  self.derivativeIndexEnd)
        return self.derivative.format(variable = name,
                                      indexes = domainindexes,
                                      overallIndex = overall_,
                                      blockIndex = block_)

    def formatRuntimeConditionNode(self, node):
        """Recursively format a runtime condition node tree.

        Each operand is wrapped in parentheses before being substituted into the
        NOT/AND/OR/EQ/NEQ/LT/LTEQ/GT/GTEQ templates.
        Raises RuntimeError for unsupported node types, operators or condition types.
        """
        res = ''
        if isinstance(node, condUnaryNode):
            value = '(' + self.formatRuntimeConditionNode(node.Node) + ')'
            if node.LogicalOperator == eNot:
                res = self.NOT.format(value = value)
            else:
                raise RuntimeError('Not supported unary logical operator: %s' % node.LogicalOperator)

        elif isinstance(node, condBinaryNode):
            leftValue = '(' + self.formatRuntimeConditionNode(node.LNode) + ')'
            rightValue = '(' + self.formatRuntimeConditionNode(node.RNode) + ')'
            if node.LogicalOperator == eAnd:
                res = self.AND.format(leftValue = leftValue, rightValue = rightValue)
            elif node.LogicalOperator == eOr:
                res = self.OR.format(leftValue = leftValue, rightValue = rightValue)
            else:
                raise RuntimeError('Not supported binary logical operator: %s' % node.LogicalOperator)

        elif isinstance(node, condExpressionNode):
            # Comparison of two runtime (adouble) expressions
            leftValue = '(' + self.formatRuntimeNode(node.LNode) + ')'
            rightValue = '(' + self.formatRuntimeNode(node.RNode) + ')'
            if node.ConditionType == eNotEQ:  # !=
                res = self.NEQ.format(leftValue = leftValue, rightValue = rightValue)
            elif node.ConditionType == eEQ:  # ==
                res = self.EQ.format(leftValue = leftValue, rightValue = rightValue)
            elif node.ConditionType == eGT:  # >
                res = self.GT.format(leftValue = leftValue, rightValue = rightValue)
            elif node.ConditionType == eGTEQ:  # >=
                res = self.GTEQ.format(leftValue = leftValue, rightValue = rightValue)
            elif node.ConditionType == eLT:  # <
                res = self.LT.format(leftValue = leftValue, rightValue = rightValue)
            elif node.ConditionType == eLTEQ:  # <=
                res = self.LTEQ.format(leftValue = leftValue, rightValue = rightValue)
            else:
                raise RuntimeError('Not supported condition type: %s' % node.ConditionType)
        else:
            raise RuntimeError('Not supported condition node: {0}'.format(type(node)))

        return res

    def formatRuntimeNode(self, node):
        """Recursively format a runtime (adouble) node tree.

        Operands of unary/binary functions are wrapped in parentheses before being
        substituted into the corresponding template. Domains, parameters,
        variables and time derivatives are delegated to the format*() methods.
        Raises RuntimeError for unsupported node types or functions.
        """
        res = ''
        if isinstance(node, adConstantNode):
            value = node.Quantity.value
            units = self.formatUnits(node.Quantity.units)
            res = self.constant.format(value = value, units = units)

        elif isinstance(node, adTimeNode):
            res = self.TIME

        elif isinstance(node, adUnaryNode):
            value = '(' + self.formatRuntimeNode(node.Node) + ')'

            if node.Function == eSign:
                res = self.SIGN.format(value = value)
            elif node.Function == eSqrt:
                res = self.SQRT.format(value = value)
            elif node.Function == eExp:
                res = self.EXP.format(value = value)
            elif node.Function == eLog:
                # eLog is emitted with the LOG10 template, eLn with the LOG template
                res = self.LOG10.format(value = value)
            elif node.Function == eLn:
                res = self.LOG.format(value = value)
            elif node.Function == eAbs:
                res = self.ABS.format(value = value)
            elif node.Function == eSin:
                res = self.SIN.format(value = value)
            elif node.Function == eCos:
                res = self.COS.format(value = value)
            elif node.Function == eTan:
                res = self.TAN.format(value = value)
            elif node.Function == eArcSin:
                res = self.ASIN.format(value = value)
            elif node.Function == eArcCos:
                res = self.ACOS.format(value = value)
            elif node.Function == eArcTan:
                res = self.ATAN.format(value = value)
            elif node.Function == eCeil:
                res = self.CEIL.format(value = value)
            elif node.Function == eFloor:
                res = self.FLOOR.format(value = value)
            elif node.Function == eSinh:
                res = self.SINH.format(value = value)
            elif node.Function == eCosh:
                res = self.COSH.format(value = value)
            elif node.Function == eTanh:
                res = self.TANH.format(value = value)
            elif node.Function == eArcSinh:
                res = self.ASINH.format(value = value)
            elif node.Function == eArcCosh:
                res = self.ACOSH.format(value = value)
            elif node.Function == eArcTanh:
                res = self.ATANH.format(value = value)
            elif node.Function == eErf:
                res = self.ERF.format(value = value)
            else:
                raise RuntimeError('Not supported unary function: %s' % node.Function)

        elif isinstance(node, adBinaryNode):
            leftValue = '(' + self.formatRuntimeNode(node.LNode) + ')'
            rightValue = '(' + self.formatRuntimeNode(node.RNode) + ')'

            if node.Function == ePlus:
                res = self.PLUS.format(leftValue = leftValue, rightValue = rightValue)
            elif node.Function == eMinus:
                res = self.MINUS.format(leftValue = leftValue, rightValue = rightValue)
            elif node.Function == eMulti:
                res = self.MULTI.format(leftValue = leftValue, rightValue = rightValue)
            elif node.Function == eDivide:
                res = self.DIVIDE.format(leftValue = leftValue, rightValue = rightValue)
            elif node.Function == ePower:
                res = self.POWER.format(leftValue = leftValue, rightValue = rightValue)
            elif node.Function == eMin:
                res = self.MIN.format(leftValue = leftValue, rightValue = rightValue)
            elif node.Function == eMax:
                res = self.MAX.format(leftValue = leftValue, rightValue = rightValue)
            elif node.Function == eArcTan2:
                res = self.ARCTAN2.format(leftValue = leftValue, rightValue = rightValue)
            else:
                raise RuntimeError('Not supported binary function: %s' % node.Function)

        elif isinstance(node, adScalarExternalFunctionNode):
            name = node.ExternalFunction.Name
            res = self.scalarExternalFunction.format(name = name)

        elif isinstance(node, adVectorExternalFunctionNode):
            name = node.ExternalFunction.Name
            res = self.vectorExternalFunction.format(name = name)

        elif isinstance(node, adDomainIndexNode):
            res = self.formatDomain(node.Domain.CanonicalName, node.Index, node.Value)

        elif isinstance(node, adRuntimeParameterNode):
            res = self.formatParameter(node.Parameter.CanonicalName, node.DomainIndexes, node.Value)

        elif isinstance(node, adRuntimeVariableNode):
            res = self.formatVariable(node.Variable.CanonicalName, node.DomainIndexes, node.OverallIndex)

        elif isinstance(node, adRuntimeTimeDerivativeNode):
            res = self.formatTimeDerivative(node.Variable.CanonicalName, node.DomainIndexes, node.OverallIndex)

        elif isinstance(node, adFEMatrixItemNode):
            res = self.feMatrixItem.format(matrixName = node.MatrixName,
                                           row = node.Row,
                                           column = node.Column,
                                           value = node.Value)

        elif isinstance(node, adFEVectorItemNode):
            res = self.feVectorItem.format(vectorName = node.VectorName,
                                           row = node.Row,
                                           value = node.Value)

        else:
            raise RuntimeError('Not supported node: %s' % type(node))

        return res