""" symbolic module
===============================================================================
Overview
===============================================================================
A framework for declaring and manipulating symbolic expressions. Suitable for
automatic differentiation and developing computation graph models of functions
for use in optimization.
Inspired by Sympy, SCIP models, Gurobi models, as well as TensorFlow / PyTorch,
this framework has the following features:
- Saving and loading of expressions is fully supported. Strings and JSON
formats are supported.
- Supports black-box functions inside expressions.
- Allows references to arbitrary non-serializable user data.
- Allows new functions to be user-defined in a straightforward way.
- Symbolic auto-differentiation.
- Conversion to and from Sympy.
- Numeric, vector, and matrix variables are allowed. Many native numpy
operations are supported.
Could it be used for learning? Yes, but it is not as fast as GPU-based
learning libraries like TensorFlow / PyTorch.
How does it compare to Sympy? This module does a better job of handling large
matrix expressions and variable indexing. For example, the multiplication of 5
symbolic 2x2 matrices takes ~200 microseconds, which is a bit slower than
direct numpy evaluation of ~50 microseconds. On the other hand, the Sympy
expression takes over 4s to construct, and each entry of the matrix contains
hundreds of terms!
This module can also take derivatives with respect to vectors using matrix
calculus. It can even auto-differentiate with nested lists/arrays! However,
function simplification is not as complete as Sympy, and derivatives for some
functions are not available. There are conversions to and from Sympy in the
symbolic_sympy module, and these are complete for most "tame" expressions.
===============================================================================
Basic usage
===============================================================================
In standard usage, first create a :class:`Context` and declare any extra
functions to be used in your library. Then, and add variables and declare
expressions as necessary.::
ctx = Context()
x = ctx.addVar("x",'N')
y = ctx.addVar("y",'N')
print(x + y) #(x+y) is an Expression object, which has not been evaluated yet
print(flatten(x,y)) #flatten(x,y) is an Expression object, which has not been evaluated yet
print((x + y).eval({'x':5,'y':3})) #prints 8 (a ConstantExpression)
print(flatten(x,y).eval({'x':5,'y':3})) #prints [5,3] (a ConstantExpression
print(flatten(x,y).eval({'y':[1,2]})) #prints [x,1,2] (an Expression object)
x.bind(5) #x is now treated as a constant
y.bind(3) #y is now treated as a constant
print((x+y).eval()) #prints 8 (a ConstantExpression)
print((x+y).evalf()) #prints 8 (as an integer)
x.unbind() #x is now a variable
y.unbind() #y is now a variable
An :class:`Expression` can take on a form ``x OP y`` or ``OP(x,y,z)``, where
x,y, and z are either :class:`Variable`, :class:`Expression`, or constant
values, and ``OP`` is either a builtin function or a declared :class:`Function`
type. Expressions can be evaluated to other Expressions via
:meth:`Expression.eval`, or to constants using
:meth:`Expression.evalf`.
There are two kinds of :class:`Variable`:
- *Standard Variable*: This type is managed by :class:`Context` and contains
information in the :class:`Variable` class. Such a variable can either be
numeric, vector, or matrix type. It will contain size information and will
be saved to and loaded from disk with the :class:`Context`.
- *User-data Variable*. This type is an unmanaged variable. These have
unlimited type, and can be complex Python objects, but the user must set
these up in the context manually. Hence, if you wish to reuse
Expressions of user-data variables loaded from disk, you will need to
first set up the ``Context`` appropriately with a similar user-data object.
By default, expressions constructed in Python refer to user-data objects with
strings, e.g., ``setConfig("robot",q)`` runs the
:func:`symbolic_klampt.setConfig` Function on the "robot" user data and the
``q`` Variable.
The :func:`expr` function is run on each argument to a :class:`Function`.
This will leave a :class:`Variable` and :class:`Expression` as is,
and most Python constants will be converted to a :class:`ConstantExpression`.
However, Python strings will be converted to references to user-data variables.
Constant values can be floats, integers, booleans, strings, lists, and numpy
arrays (Dictionaries are not supported as constant values.) Constants are
converted to expressions using :func:`const`, or :func:`expr`.
Type specifiers are used for type checking and to determine the type of
expressions, particularly for Jacobians. The :class:`Type` class is used for
this. The type member specifies the general type of object, and is given by a
character:
- N: generic numeric
- I: integer
- B: boolean
- X: index (either an integer or a slice)
- V: 1D vector
- M: 2D matrix
- A: n-D array
- L: list
- U: user data
- None: indicates an unknown type
as well as an optional size and sub-type. For L and V types, size is the length
of the array. For M and A types, size is the Numpy shape of the array. For
the U type, the subtype member indicates ``__class__.__name__`` if known.
===============================================================================
Standard functions
===============================================================================
Standard operations on expressions include:
- ``expr(x)``: returns an ``Expression`` corresponding to x, whether x is a
constant, ``Variable``, or ``Expression``
- ``deriv(expr,var)``: take the derivative of expr with respect to var. If no
derivative information is available for some of the expressions in expr,
returns None.
- ``simplify(expr,context=None)``: simplifies the expression in the optimal given
context. Note: this is not terribly powerful in terms of rearranging
expressions (like x - x is not converted to 0).
- ``is_const(x)``: returns True if x evaluates to a constant.
- ``to_const(x)``: returns the constant to which x evaluates, if ``is_const(x)``.
- ``is_expr(x)``: returns True if x is an ``Expression``.
- ``is_var(x)``: returns True if x is a ``Variable`` or a ``VariableExpression``
(i.e., monomial).
- ``to_var(x)``: returns the ``Variable`` corresponding to the ``Variable`` or
monomial x.
- ``is_op(x,func=None)``: returns True if x is an operator. If func is given,
checks whether this is the the name of the function.
- ``is_zero(x)``: returns True if x is the zero(...) operator or a scalar equal to
0.
- ``is_scalar(x,val=None)``: returns True if x is known to evaluate to a scalar.
If val is provided, it checks whether x is equal to val.
- ``type_of(x)``: returns the Type corresponding to x.
.. note::
These are Python functions, not symbolic Functions.
===============================================================================
Symbolic Functions
===============================================================================
You can build an :class:`Expression` by calling a symbolic :class:`Function`,
many of which duplicate the functions of plain Python, the math module, or
Numpy. Calling a :class:`Function` does not immediately evaluate
the function on its arguments, but instead builds an :class:`Expression` that
is the root of a symbolic Expression Graph.
Built-in symbolic Functions include:
Shape functions
---------------
- ``range_(n)``: Evaluates to range(n). The result can be treated like a
vector or a list.
- ``dims(x)``: Returns the dimensionality of the input. Equivalent to
``len(shape(x))``.
- ``len_(x)``: Evaluates to len(x), except if x is a scalar then it evaluates
to 0. If x is a multi-dimensional array, this is the length of its first
dimension. Undefined for other forms of user data.
- ``count(x)``: Evaluates the number of numeric parameters in x. Works with
scalars, arrays, and lists. If x is a scalar then it evaluates to 1.
Lists are evaluated recursively. Undefined for other forms of user data.
- ``shape(x)``: Evaluates to ``x.shape`` if x is a Numpy array, ``(len_(x),)``
if x is a vector, and () if x is a scalar. If x is a list, this is
``(len_(x),)+shape(item)`` if all of the items have the same shape,
and otherwise it is a hyper-shape. (shortcut: "x.shape")
- ``reshape(x,s)``: Evaluates to x reshaped to the shape s. If x is a scalar,
this evaluates to a constant matrix.
- ``transpose(x)``: Evaluates to np.transpose(x). (shortcut: "x.T")
- ``transpose2(x,axes)``: Tensor transpose; evaluates to np.transpose(x,axes).
- ``basis(i,n)``: Evaluates to the i'th elementary basis vector in dimension n.
- ``eye(n)``: Evaluates to np.eye(n) if n > 0, otherwise returns 1.
- ``zero(s)``: Evaluates to np.zeros(s) if s is a matrix shape or scalar > 0,
otherwise evaluates to 0.
- ``diag(x)``: Evaluates to np.diag(x).
- ``flatten(*args)``: Evaluates to a vector where all arguments are stacked
(concatenated) into a single vector. Arrays are reduced to vectors by
Numpy's flatten(), and complex objects are reduced via recursive flattening.
- ``row_stack(*args)``: Evaluates to a matrix where all arguments are stacked
vertically. 1D arrays are treated as row vectors. If only a single list
element is provided, then all arguments are stacked.
- ``column_stack(*args)``: Evaluates to a matrix where all arguments are stacked
horizontally. 1D arrays are treated as column vectors. If only a single
list element is provided, then all arguments are stacked.
Comparisons and logical functions
----------------------------------
- ``eq(lhs,rhs)``: Evaluates to lhs = rhs (shortcut: "lhs = rhs").
- ``ne(lhs,rhs)``: Evaluates to lhs != rhs (shortcut: "lhs != rhs").
- ``le(lhs,rhs)``: Evaluates to lhs <= rhs (shortcut: "lhs <= rhs").
- ``ge(lhs,rhs)``: Evaluates to lhs >= rhs (shortcut: "lhs >= rhs").
- ``not_(x)``: Evaluates to not x (shortcut: "not x").
- ``or_(x,y)``: Evaluates to x or y (shortcut: "x or y").
- ``and_(x,y)``: Evaluates to x and y (shortcut: "x and y").
- ``any_(*args)``: Evaluates to ``any(*args)``
- ``all_(*args)``: Evaluates to ``all(*args)``
Arithmetic functions
---------------------
- ``neg(x)``: Evaluates to -x (shorcut "-x").
- ``abs_(x)``: Evaluates to abs(x) (shortcut: "abs(x)"). Works with arrays too
(elementwise)
- ``sign(x)``: Evaluates the sign of x. Works with arrays too (elementwise).
- ``add(x,y)``: Evaluates to x + y (shortcut: "x + y"). Works with arrays too.
- ``sub(x,y)``: Evaluates to x - y (shortcut: "x - y"). Works with arrays too,
and vector - scalar.
- ``mul(x,y)``: Evaluates to x * y (shortcut: "x * y"). Works with arrays too
(elementwise multiplication).
- ``div(x,y)``: Evaluates to x / y (shortcut: "x / y"). Works with arrays too
(elementwise division), and vector / scalar.
- ``pow_(x,y)``: Evaluates to ``pow(x,y)`` (shortcut "x**y").
- ``dot(x,y)``: Evaluates to ``np.dot(x,y)``.
- ``outer(x,y)``: Evaluates to ``np.outer(x,y)``.
- ``tensordot(x,y,axes)``: Evaluates to ``np.tensordot(x,y,axes)``.
- ``max_(*args)``: Evaluates to the maximum of the arguments.
- ``min_(*args)``: Evaluates to the minimum of the arguments.
- ``argmax(*args)``: Evaluates to the index of the maximum of the argments.
- ``argmin(*args)``: Evaluates to the index of the minimum of the argments.
- ``cos(x)``: Evaluates to math.cos(x).
- ``sin(x)``: Evaluates to math.sin(x).
- ``tan(x)``: Evaluates to math.tan(x).
- ``sqrt(x)``: Evaluates to math.sqrt(x).
- ``exp(x)``: Evaluates to math.exp(x).
- ``log(x)``: Evaluates to math.log(x) (base 10).
- ``ln(x)``: Evaluates to math.ln(x) (natural log).
- ``sum_(*args)``: Evaluates to sum(args). If arguments are vectors or matrices,
then the result is also a vector or matrix. This is somewhat different
behavior from sum(x) if x is a list.
- ``weightedsum(v1,...,vn,w1,...,wn)``: Evaluates to :math:`w1*v1+...+wn*vn`.
Accessors
---------
- ``getitem(vec,index)``: Evaluates to ``vec[index]``. This also supports slices
and tuples, as well as lists (Numpy fancy indexing). (shortcut:
"vec[index]")
- ``setitem(vec,index,val)``: Evaluates to vec except with vec[index] set to val.
Equivalent to Python code::
temp = vec[:]
vec[index]=val
return temp
- ``getattr_(object,attr)``: returns the value of a given attribute under the
given object. For example, ``getattr_(traj,const("milestones"))`` gets the
milestone list of a user-data ``Trajectory`` named ``traj``.
If the result is a function, it will be called with no arguments. For
example, ``getattr_(robot,const("getJointLimits"))`` will return the robot's
joint limits.
(shortcut: "object.attr", where object is a UserDataExpression)
- ``setattr_(object,attr,val)``: returns a modified version of the given class
object, where ``value`` is assigned to the attribute ``attr``. For example,
``setattr_(traj,const("times"),[0,0.5,0.1])`` sets the ``times`` attribute
of a ``Trajectory`` to [0,0.5,1].
Note: this operation modifies the object itself.
If the attribute ``attr`` is a function, it will be called with the argument
``val``. This allows setting operations to be called.
Conditionals
------------
- ``if_(cond,trueval,falseval)``: If cond evaluates to True, this expression
evaluates to trueval. Otherwise, it evaluates to falseval.
Arrays
-------
Arrays are mostly interchangable with vectors (numpy 1-D arrays), except they
can contain objects of varying type/size. Arrays of varying type/size should
only be used as arguments in the flatten or the special looping functions
(``forall``, ``summation``, etc). In many cases the Python construction
``[e1,e2,...,en]`` where en is an ``Expression``, will be interpreted correctly
as an ``array`` Expression. To handle nested lists of Expressions, and
to ensure that your lists can be saved and loaded, you may need to use these
functions.
- ``array(*args)``: Creates a list or numpy array of the given arguments. This can
accept arbitrary arguments, such as variable size vectors. A Numpy array is
produced only if the items have compatible types and dimensions.
- ``list_(*args)``: Creates a list of the given arguments. This can accept
arbitrary arguments, such as variable size vectors. No attempt is made to
convert to a numpy array.
- ``tuple_(*args)``: Creates a tuple of the given arguments. This can accept
arbitrary arguments, such as variable size vectors.
- ``zip_(collection1,collection1,...)``: Does the same thing as the Python zip
function, returning a list of tuples.
Looping
--------
Special functions are available for temporary variable substitution and
emulation of ``for`` loops. In each of the following, var can be a variable
or userData referenced in the ``Expression`` ``expr``:
- ``subs(expr,var,value)``: evaluates expr with var substituted with value. For
example, ``subs(const(2)*"i","i",3.5)`` yields 2*3.5
- ``map_(expr,var,values)``: like the Python ``map`` function, evaluates to a
list where each entry evaluates ``expr`` with ``var`` substituted with
a value from the list ``values``. For example, if x is a ``Variable``, then
``map_(x**"i","i",range_(3))`` yields the list ``[x**0, x**1, x**2]``
- ``forall(expr,var,values)``: True if, for every value in the list ``values``,
``expr`` evaluates to nonzero when ``var`` is substituted with that value.
Equivalent to ``all_(*[subs(expr,var,value) for value in values])``
- ``forsome(expr,var,values)``: True if, for some value in the list ``values``,
``expr`` evaluates to nonzero when ``var`` is substituted with that value.
Equivalent to ``any_(*[subs(expr,var,value) for value in values])``
- ``summation(expr,var,values)``: The sum of ``expr`` over ``var`` when ``var``
is substituted with each value in the list ``values``. Equivalent to
``sum_(*[subs(expr,var,value) for value in values])``
Nested iteration can be performed, such as
``summation(summation(expr("x")**"y","x",[1,2]),"y",[0,1])``
===============================================================================
Defining your own Functions
===============================================================================
If you want to use your own Python functions, you can use ``Context.declare``.
By default this will use the same function signature that you used to define
the function, or you can provide a new signature (name, arguments).
The convention used in the built-in symbolic libraries is to prepend an
underscore to the underlying Python function, and then declare the function
to the undecorated name.::
def _f(x,y):
return 2*x*y
ctx = Context()
f = ctx.declare(_f,"f")
print(f(3,4)) #prints f(3,4)
print(f(3,4).evalf()) #prints 24, since 2*3*4=12.
At this point, the module does not know how to take derivatives, so any
``deriv(f(...),arg)`` call will return None. To set derivatives, you can
use the ``setDeriv`` or ``setJacobian`` functions::
f.setDeriv("x",(lambda x,y,dx:2*dx*y))
f.setDeriv("y",(lambda x,y,dy:2*x*dy))
OR::
f.setJacobian("x",(lambda x,y:2*y))
f.setJacobian("y",(lambda x,y:2*x))
Alternatively, you can declare an ``Expression`` as a ``Function``. Here you
need to define the function name and argument list so that unbound variables in
the ``Expression`` are bound to the arguments.::
f = ctx.declare(const(2)*expr("x")*expr("y"),"f",["x","y"])
This form will automatically obtain derivatives for you.
Functions can also specify argument types and return types, which are used
for type checking at Expression creation time. This is quite helpful for
debugging, since you do not have to evaluate the ``Expression`` to find
mismatched types, like indexing a vector with a list. To declare types for
custom functions, use the ``setArgType`` and ``setReturnType`` methods::
#numeric input and output
f.setArgType(0,'N')
f.setArgType(1,'N')
f.setReturnType('N')
===============================================================================
Gotchas
===============================================================================
Most conversions of arguments to Expressions are done for you during the
course of calling operators or calls to symbolic functions, but if you call
operators on Python constants, you will invoke the standard Python operation.
For example, 5 + 7 produces 12 because 5 and 7 are Python constants, but if you
want to build the symbolic ``Expression`` that represents "5 + 7", you will need
to use ``const(5)+const(7)`` or ``expr(5)+expr(7)``.
Since raw strings are interpreted as references to user-data variables, to
create string constants, you will need to wrap the string using the ``const()``
function, e.g., ``const("text")``. This is mostly relevant for the
``get/setattr_`` functions. Suppose Point is a class with members Point.x and
Point.y, the expression that accesses the x member of the user data ``Point p``
is built with ``getattr_("p",const("x"))``. Using ``__getattr__`` operator
overloading you can also use ``expr("p").x``
Python Lists as arguments are handled in somewhat of a tricky way. If the list
is compatible with a Numpy array (i.e., each element is numeric or equal-sized
arrays), it will be converted to a Numpy array. Otherwise, (e.g., a non-
constant Expression is inside or the elements do not have the same size) it will
be converted to an ``array(e1,...,en)`` Expression. Tuples are NOT
converted in the same way, so that Numpy matrix indices can be preserved.
Tuples of expressions are not supported implicitly, instead you will have to use
``tuple_(e1,...,en)``.
Comparison testing on Expressions, such as ``e != 0`` or ``e1 == e2``, do
not return True or False. Instead, they return Expressions themselves! As
a result, standard Python if statements cannot be used on Expressions. Instead,
you will need to run ``eval()`` or ``evalf()`` on the result if you want to get
its truth value. (``eval()`` does return an Expression, but if the result of
``eval()`` is a ConstantExpression then it may be directly tested for its truth
value.) To help you remember this, an exception will be raised if you try to use
an expression as a condition in a Python ``if`` statement.
Note also that ``e1 == e2`` does NOT test for whether the two expressions are
logically equivalent. As an example, the expression ``x1+x2 == x2+x1`` does not
test equivalence as you might expect, nor would even ``eval()`` to true unless
x1 and x2 were given. In order to test whether two expressions are
*syntactically* equivalent, you can use ``e1.match(e2)``. It is computationally
intractable in general to determine whether two expressions are logically
equivalent, so we don't even try.
Derivatives are handled as usual for scalar/vector quantities, giving Jacobian
matrices in the form::
[df1/dx1 ... df1/dxn ]
[... ]
[dfm/dx1 ... dfm/dxn ].
However, if the function or variable has an "exotic" type, like a matrix,
tensor, or nested list, then the Jacobian is taken with respect to the
**flattened** version of the function and the variable. The way this is
handled can be seen in some complex derivative expressions, which will have
various reshape operations.
Many standard Python operators -- +,-,*,/,**, and, or, not, comparison tests,
the [] indexing operator, and list construction via [a,b] -- are supported
directly on Variables and Expressions. Compound slices (``matrix[:,:]``) have
not yet been thoroughly tested. Other standard functions -- len, min, max, sum,
abs, any, and all -- have direct analogues in this package, with a trailing
underscore added. If statements and list comprehensions can be emulated via the
``if_(cond,trueval,falseval)`` statement and the ``map_(expr,var,range)
statement``. There is no support for the in statement, bitwise operators, or
procedural code (e.g., statements with side-effects).
If-elif-...-else blocks can be emulated using a multiplexer
``array(expr1,expr2,...,exprn)[switch]`` where ``switch`` takes on the values
0,...,n-1. This expression is lazy-evaluated so that only the selected
expression is expanded. Standard ``if_`` statements are also lazy-evaluated.
Thread-safety: Expression evaluation, derivatives, and simplification are NOT
thread safe. Common Expressions used between threads should be deep-copied
before use.
===============================================================================
Performance and Expression DAGs
===============================================================================
If your expression repeatedly performs computationally expensive operations, it
is helpful to gather sub-expressions into the same Python objects. As an
example, consider a Variable a, and the expression d given by the following
construction::
b = (a+a)
c = (b+b)
d = (c+c)
Expanded, d = (c+c) = ((b+b)+(b+b)) = (((a+a)+(a+a))+((a+a)+(a+a)). However,
when ``d.eval({'a':1})`` is called, this module is smart enough to evaluate a,
b, and c only once. This becomes really helpful in derivative evaluation as
well, since the derivative expression will also try to reuse common sub-
expressions.
Now what if we want to evaluate several expressions at once? This naive code
will evaluate a, b, c, and d 1000 times::
ds = [(d+i).eval({'a':1}) for i in range(1000)]
In order to take advantage of the sub-expression d, we can put everything into a
single ``array()`` expression::
ds = array(*[(d+i) for i in range(1000)]).eval({'a':1})
which will only evaluate a,b,c, and d once.
===============================================================================
IO
===============================================================================
There are three forms of Expression IO in symbolic_io:
1. Standard Python printing: ``str(expr)`` (output only)
2. Human-readable strings: :func:`symbolic_io.toStr` (output) /
:func:`symbolic_io.fromStr` (input)
3. JSON-like objects: :func:`symbolic_io.toJson` (output) /
:func:`symbolic_io.fromJson` (input)
Method 1 is the most readable when printed to the console, but can't be read.
Method 2 is somewhat readable, and is compatible with the optimization problem
editor. Note that there may be some subtle problems with IO for lists vs numpy
arrays, since lists are automatically converted to numpy arrays when possible.
It hasn't been tested extremely thoroughly yet.
Method 3 is the least ambiguous but it uses a nested JSON structure to represent
the expression graph. Hence it is rather verbose.
In Methods 2 and 3, if the expression graph is a DAG rather than a tree, a
special notation is used to represent repeated references to nodes. In
particular, any common ConstantExpression and OperatorExpression objects are
"tagged" with a unique id string.
In Method 2, the first time a sub-expression appears it's tagged with a suffix
#id. Then, subsequent references are retrieved with the expression @id. As an
example, the string ``(x+y)#1*@1`` evaluates to ``(x+y)*(x+y)``, since '#1'
assigns ``(x+y)`` to the id '1' and '@1' retrieves it.
===============================================================================
Code generation
===============================================================================
- context.makePyFunction(expr,varorder=None): creates a Python function that
evaluates ``expr`` given arguments whose values are to be assigned to each
``Variable`` in ``varorder``.
- context.makeFlatFunction(expr,varorder=None): creates a Python function that
evaluates ``expr`` given a flattened list or vector of parameters, which can
be unraveled to values to be assigned to each Variable in ``varorder``.
- symbolic_io.latex(expr): produces a LaTex string representing the expression
(requires Sympy)
- symbolic_io.codegen(exprs,language): produces C / Matlab code to evaluate the
``Function`` or named expressions (requires Sympy).
===============================================================================
Sympy integration
===============================================================================
Sympy integration is quite automatic. See :func:`symbolic_sympy.exprToSympy`
and :func:`symbolic_sympy.exprFromSympy` in symbolic_sympy.py. Built-in
functions are converted mostly bidirectionally and hence support all aspects of
both libraries.
Array operations are converted to Sympy Matrix operations, which expand all
entries into scalar expressions, so array operations like dot() cannot be
converted to Sympy and then back into the equvalent symbolic operation.
Custom symbolic functions are converted to sympy Function(f) objects, and it
should be noted that these do not support code generation.
Special Sympy functions are automatically converted for use in the symbolic
module, including differentiation. (Note that they cannot be saved to / loaded
from disk, though.)
===============================================================================
Wish list
===============================================================================
- Sparse matrix support when obtaining jacobians -- especially in block matrix
form.
- Compiled code generation on vectors/matrices -- maybe integration with
TensorFlow / PyTorch?
===============================================================================
Module summary
===============================================================================
External interface
-----------------------
.. autosummary::
expr
Type
deriv
simplify
supertype
type_of
const
is_const
to_const
is_scalar
to_scalar
is_zero
is_var
to_var
is_expr
is_op
is_sparse
to_monomial
to_polynomial
Internal implementation
-----------------------
.. autosummary::
Type
Context
Function
Variable
Wildcard
Expression
ConstantExpression
UserDataExpression
VariableExpression
OperatorExpression
Symbolic standard functions
---------------------------
.. autosummary::
range_
len_
count
shape
reshape
transpose
transpose2
dims
eye
basis
zero
diag
eq
ne
le
ge
not_
or_
and_
neg
abs_
sign
add
sub
mul
div
pow_
dot
outer
tensordot
if_
max_
min_
argmax
argmin
cos
sin
tan
arccos
arcsin
arctan
arctan2
sqrt
exp
log
ln
sum_
any_
all_
getitem
setitem
getattr_
setattr_
flatten
row_stack
column_stack
array
list_
tuple_
zip_
weightedsum
subs
map_
forall
forsome
summation
"""
import weakref
from . import vectorops
import operator
import numpy as np
import math
import copy
import itertools
import warnings
from builtins import object
_DEBUG_CACHE = False
_DEBUG_DERIVATIVES = False
_DEBUG_SIMPLIFY = False
_DEBUG_TRAVERSE = False
#all recursion items to debug. Warning: don't include 'depth' or None
_DEBUG_TRAVERSE_ITEMS = {}
SCALAR_TYPES = 'NIB'
ARRAY_TYPES = 'AVM'
BUILTIN_TYPES = 'NIBAVMX'
VAR_DERIV_PREFIX = 'D'
TYPE_CHECKING = True
SHAPE_CHECKING = True
_PY_INT_TYPES = (int,np.int32,np.int64)
_PY_FLOAT_TYPES = (float,np.float32,np.float64)
_PY_NUMERIC_TYPES = _PY_INT_TYPES + _PY_FLOAT_TYPES
_PY_PRIMITIVE_TYPES = _PY_NUMERIC_TYPES + (bool,str)
_PY_CONST_TYPES = _PY_INT_TYPES + _PY_FLOAT_TYPES + (bool,np.ndarray)
[docs]class Type:
"""A specification of a variable/expression type.
Attributes:
char (str): a character defining the type. Valid values are
N, I, B, X, A, V, M, L, U, and None.
size: None (no size specified), the length of the array (L
and V types), or the Numpy shape (A, M types)
subtype (str, Type, or list of Types, optional):
* U type: the name of the class, if known.
* L type: either a single Type defining the list's element
types, or a list of Types defining individual element
types (of the same length as size).
"""
def __init__(self,type,size=None,subtype=None):
if isinstance(type,Type):
self.char = type.char
self.size = type.size if size is None else size
if subtype is None:
self.subtype = type.subtype
elif isinstance(subtype,list):
self.subtype = [Type(s) for s in subtype]
else:
self.subtype = Type(subtype)
else:
if type is not None:
assert isinstance(type,str),"Type argument "+str(type)+" must be a string, not "+type.__class__.__name__
if type is not None and len(type) > 1:
#assume it's a user type
self.char = 'U'
self.size = size
self.subtype = type
else:
self.char = type
self.size = size
if subtype is None:
self.subtype = subtype
elif isinstance(subtype,list):
self.subtype = [Type(s) for s in subtype]
else:
self.subtype = Type(subtype)
assert self.size is None or isinstance(self.size,_PY_INT_TYPES) or isinstance(self.size,(list,tuple,dict,np.ndarray)),"Erroneous type of size? "+str(size)+": "+size.__class__.__name__
if isinstance(self.size,dict):
assert self.char in 'AL',"Hyper-shape objects must have A or L type"
self.char = 'L'
subtype = self.subtype
size = self.size
self.size = len(self.size)
self.subtype = [None]*self.size
for (k,v) in size.items():
if isinstance(v,dict):
self.subtype[k] = Type('L',v,subtype)
else:
self.subtype[k] = Type('A',v,subtype)
elif isinstance(self.size,np.ndarray):
self.size = tuple(self.size.astype(np.int64))
elif isinstance(self.size,list):
self.size = tuple(self.size)
if self.char == 'A' and isinstance(self.size,_PY_INT_TYPES):
self.size = (self.size,)
if self.char is not None and self.char in 'AM':
assert not isinstance(self.size,_PY_INT_TYPES)
[docs] def is_scalar(self):
if self.char is None: return False
return self.char in SCALAR_TYPES
[docs] def shape(self,hypershape=True):
"""Returns the Numpy shape or hypershape of this type. If it's array-
like, the return value is a tuple. Otherwise, it's a dict.
A ValueError is raised if no size is specified.
"""
if self.char is None:
raise ValueError("Item of unknown type does not map to a Numpy shape")
if self.char in SCALAR_TYPES: return ()
if self.size is None:
raise ValueError("Item of type "+self.char+" does not have a defined size")
elif self.char == 'V': return (self.size,)
elif self.char in ['M','A']: return self.size
if self.char == 'L' and hypershape:
if self.subtype is not None:
res = {}
if isinstance(self.subtype,list):
for i,st in enumerate(self.subtype):
res[i] = st.shape()
else:
stsh = self.subtype.shape()
for i in range(self.size):
res[i] = stsh
return res
raise ValueError("Item of type "+self.char+" does not map to a Numpy shape")
[docs] def len(self):
"""Returns the number of first-level sub-items of this type. None is returned if no
size is specified."""
if self.char is None:
raise ValueError("Item of unknown type does not have a len")
if self.char in SCALAR_TYPES: return 0
elif self.char == 'V': return self.size
elif self.char == 'L':
return self.size
elif self.char in 'MA':
return self.size[0]
else:
raise ValueError("Type "+self.char+" has no len")
[docs] def count(self):
"""Returns the number of free parameters in this type. None is returned if no
size is specified."""
if self.char is None:
raise ValueError("Item of unknown type does not have a count")
if self.char in SCALAR_TYPES: return 1
if self.char == 'L':
try:
if self.subtype is not None:
if isinstance(self.subtype,list):
res = sum(s.count() for s in self.subtype)
assert isinstance(res,int),"Result of count isn't an integer? %s"%(res.__class__.__name__,)
return res
elif self.subtype.len() is not None:
res = self.size*self.subtype.count()
assert isinstance(res,int),"Result of count isn't an integer? %s"%(res.__class__.__name__,)
return res
return None #no length specified
except Exception:
return None
return self.size
if self.char in 'MA':
if self.size is None: return None
return int(np.product(self.size))
assert isinstance(self.len(),int),"Result of count isn't an integer? %s"%(self.len().__class__.__name__,)
return self.len()
[docs] def dims(self):
"""Returns the number of entries in the shape of this type. For compound list types, this is
1+subtype.dims()."""
if self.char is None:
raise ValueError("Item of unknown type does not have a dimensionality")
if self.char in SCALAR_TYPES: return 0
elif self.char == 'V': return 1
elif self.char == 'M': return 2
elif self.char == 'A':
if self.size is None: raise ValueError("Array of unspecified shape does not have a defined dimensionality")
return len(self.size)
elif self.char == 'L':
if self.subtype is None:
raise ValueError("List of unspecified subtype does not have a defined dimensionality")
if isinstance(self.subtype,list):
return 1 + max(st.dims() for st in self.subtype)
return 1 + self.subtype.dims()
raise ValueError("Item of type "+self.char+" does not have a dimensionality")
[docs] def match(self,obj,strict=False):
"""Returns true if the given object type matches all aspects of this specification.
Integers and booleans match with numeric, vectors and matrices match with array.
obj is a Type, str, or Python object. If obj is a str, it is either a character specifier
or a class name, and sizes aren't checked. If obj is an object this is equivalent to
self.match(type_of(obj)).
If obj doesn't have a size then it matches only strict = False, or if this
specification also doesn't have a size.
"""
if isinstance(obj,Type):
if obj.char is None and not strict: return True
if self.char is None: return True
elif obj.char == 'U':
if self.char == 'U':
return self.subtype is None or obj.subtype is None or self.subtype == obj.subtype
else:
return obj.subtype is None or self.match(obj.subtype,strict)
elif self.char == 'N' and obj.char in SCALAR_TYPES: return True
elif self.char == 'A' and obj.char in ARRAY_TYPES:
if self.size is None: return True
if obj.char == 'V': return self.size[0] == obj.size
return self.size == obj.size
elif obj.char == 'A' and self.char in ARRAY_TYPES:
if obj.size is None: return True
if self.char == 'V': return obj.size[0] == self.size
return self.size == obj.size
elif self.char == 'L':
if obj.char in ARRAY_TYPES:
if self.size is None: return True
if obj.size is None: return not strict
assert isinstance(self.size,_PY_INT_TYPES),"Invalid size of list: "+str(self.size)
if obj.len() != self.size: return False
if self.subtype is None: return True
if isinstance(self.subtype,list):
ost = obj.itemtype()
return all(st.match(ost,strict) for st in self.subtype)
return self.subtype.match(obj.itemtype(),strict)
elif obj.char == 'L':
if self.size is None: return True
if obj.size is None: return not strict
if obj.size != self.size: return False
if self.subtype is None: return True
if isinstance(self.subtype,list):
if isinstance(obj.subtype,list):
return all(st.match(ost,strict) for st,ost in zip(self.subtype,obj.subtype))
return all(st.match(obj.subtype,strict) for st in self.subtype)
return self.subtype.match(obj.itemtype(),strict)
else:
return False
elif self.char == 'X':
return obj.char in 'VIX' or (not strict and obj.char == 'N')
if not strict and self.char in SCALAR_TYPES and obj.char in SCALAR_TYPES: return True
return self.char == obj.char and (self.size is None or (obj.size is None and not strict) or self.size == obj.size)
elif isinstance(obj,str):
if self.char is None: return True
elif self.char == obj: return True
elif self.char == 'N': return obj in SCALAR_TYPES
elif self.char == 'A': return obj in ARRAY_TYPES
elif self.char == 'L': return obj in ARRAY_TYPES
elif self.char == 'X': return obj in 'VIX'
elif self.char == 'U': return self.subtype is None or obj == self.subtype
if not strict and self.char in SCALAR_TYPES and obj in SCALAR_TYPES: return True
return False
elif obj is None: return not strict
else:
return self.match(type_of(obj),strict)
[docs] def itemtype(self,index=None):
"""If this is an array or list type, returns the Type of each item accessed by x[i].
Otherwise returns None.
"""
if self.char == 'V':
if self.subtype is None:
return Type('N')
else:
return self.subtype
elif self.char == 'M':
if self.size is None:
return Type('V',None,self.subtype)
else:
return Type('V',self.size[1],self.subtype)
elif self.char == 'A':
if self.size is None:
return Type(None)
if len(self.size) == 0: return None
elif len(self.size) == 1: return Type('N') if self.subtype is None else self.subtype
else: return Type(self,self.size[1:])
elif self.char == 'L':
if isinstance(self.subtype,list):
if index is not None:
return self.subtype[index]
return supertype(self.subtype)
return self.subtype
else:
return None
def __str__(self):
"""Returns a string representation of this type, suitable for brief printouts"""
if self.char is None:
return 'None'
if self.size is None and self.subtype is None:
return self.char
elif self.subtype is None:
return self.char+' size '+str(self.size)
elif self.size is None:
if isinstance(self.subtype,list):
return self.char+' (subtypes '+','.join(str(s) for s in self.subtype)+')'
else:
return self.char+' (subtype '+str(self.subtype)+')'
else:
if isinstance(self.subtype,list):
return self.char+' size '+str(self.size)+' (subtypes '+','.join(str(s) for s in self.subtype)+')'
else:
return self.char+' size '+str(self.size)+' (subtype '+str(self.subtype)+')'
_TYPE_DESCRIPTIONS = {'N':'numeric','I':'int','B':'bool','X':'index (int or slice)',
'V':'vector','M':'matrix','A':'array','L':'list','U':'user data'}
[docs] def info(self):
"""Returns a verbose, human readable string representation of this type"""
if self.char is None:
return 'unknown'
typedesc = "%s (%s)"%(self._TYPE_DESCRIPTIONS[self.char],self.char)
if self.size is not None:
if isinstance(self.size,(list,tuple)):
sizedesc = 'x'.join(str(v) for v in self.size)
typedesc = sizedesc + ' ' + typedesc
else:
sizedesc = str(self.size)
typedesc = sizedesc + '-' + typedesc
if self.subtype is None:
return typedesc
else:
if self.char == 'U':
return '%s of type %s'%(typedesc,self.subtype)
if isinstance(self.subtype,list):
return '%s of subtypes [%s]'%(typedesc,','.join(str(s) for s in self.subtype))
else:
return '%s of subtype (%s)'%(typedesc,self.subtype.info())
Numeric = Type('N')
Integer = Type('I')
Boolean = Type('B')
Index = Type('X')
Vector = Type('V')
Vector2 = Type('V',2)
Vector3 = Type('V',3)
Matrix = Type('M')
Matrix2 = Type('M',(2,2))
Matrix3 = Type('M',(3,3))
def _is_exactly(a,b):
"""Returns True if these are exactly the same class and match.
Works for numpy arrays, scalars, and class objects (this latter
case is matched with the 'is' keyword)
"""
if a.__class__ == b.__class__:
if isinstance(a,np.ndarray):
return np.all(a==b)
elif isinstance(a,_PY_PRIMITIVE_TYPES + _PY_FLOAT_TYPES):
return a == b
return a is b
return False
def _ravel(obj):
"""A hyper-version of numpy's ravel function that also works for nested lists"""
if isinstance(obj,np.ndarray):
return np.ravel(obj)
elif isinstance(obj,(list,tuple)):
hstack_args = [_ravel(v) for v in obj]
hstack_args = [v for v in hstack_args if len(v) > 0]
if len(hstack_args) == 0:
return np.array([])
elif len(hstack_args) == 1:
return hstack_args[0]
try:
return np.hstack(hstack_args)
except ValueError as e:
return np.array([])
else:
return np.array([obj])
def _hyper_count(obj):
if isinstance(obj,np.ndarray):
return np.product(obj.shape)
elif isinstance(obj,(list,tuple)):
return sum(_hyper_count(v) for v in obj)
else:
return 1
def _has_hyper_shape(obj):
if isinstance(obj,np.ndarray):
return False
elif isinstance(obj,(list,tuple)):
if any(_has_hyper_shape(v) for v in obj): return True
subshapes = [_hyper_shape(v) for v in obj]
if len(subshapes) == 0: return False
if all(isinstance(s,tuple) and s==subshapes[0] for s in subshapes):
return False
return True
else:
return ()
def _hyper_shape(obj):
if isinstance(obj,np.ndarray):
return obj.shape
elif isinstance(obj,(list,tuple)):
subshapes = [_hyper_shape(v) for v in obj]
if len(subshapes) == 0: return ()
if all(isinstance(s,tuple) and s==subshapes[0] for s in subshapes):
return (len(obj),)+subshapes[0]
return dict([(i,s) for i,s in enumerate(subshapes)])
else:
return ()
def _is_numpy_shape(sh):
if isinstance(sh,(tuple,list)):
return (len(sh)==0 or all(isinstance(x,_PY_INT_TYPES) for x in sh))
elif isinstance(sh,np.ndarray):
return len(sh.shape) == 1 and sh.dtype.char != 'O'
return False
def _is_hyper_shape(sh):
if isinstance(sh,(tuple,list)):
return len(sh)==0 or all(isinstance(x,_PY_INT_TYPES) for x in sh)
elif isinstance(sh,np.ndarray):
return len(sh.shape) == 1 and sh.dtype.char != 'O'
elif isinstance(sh,dict):
for (k,v) in sh.items():
if not isinstance(k,_PY_INT_TYPES): return False
if not _is_hyper_shape(v): return False
return True
def _concat_hyper_shape(sh1,sh2):
"""Describes the shape of the object where every primitive element of the type described by sh1
is replaced by an element of the type described by sh2. In other words, performs the Cartesian
product."""
if _is_numpy_shape(sh2):
if isinstance(sh1,(tuple,list,np.ndarray)):
return tuple(sh1) + tuple(sh2)
else:
assert isinstance(sh1,dict),"Erroneous hypershape object class "+sh1.__class__.__name__
res = {}
for (k,v) in sh1.items():
res[k] = _concat_hyper_shape(v,sh2)
return res
else:
if isinstance(sh1,(tuple,list,np.ndarray)):
if len(sh1) == 0:
return sh2
elif len(sh1) == 1:
return dict((i,sh2) for i in range(sh1[0]))
else:
raise NotImplementedError("TODO: cartesian product of multidimensional hyper-shapes")
else:
assert isinstance(sh1,dict),"Erroneous hypershape object class "+sh1.__class__.__name__
res = {}
for (k,v) in sh1.items():
res[k] = _concat_hyper_shape(v,sh2)
return res
def _unravel(obj,shape):
"""The reverse of _ravel:
- obj is a 1-D Numpy array (containing a sufficient number of items that can be extracted)
- shape is a Numpy shape or a hyper-shape.
Returns a nested structure of numpy array or lists in the hyper-shape specified by shape.
"""
assert isinstance(obj,np.ndarray)
olen = obj.shape[0]
if isinstance(shape,dict):
res = [None]*len(shape)
ofs = 0
for (i,s) in shape.items():
res[i] = _unravel(obj[ofs:],s)
ofs += _hyper_count(res[i])
if ofs > olen: raise ValueError("Insufficent sized vector to extract the desired parameters")
return res
else:
if len(shape) == 0:
if olen == 0: raise ValueError("Insufficent sized vector to extract the desired parameters")
return obj[0]
arrlen = np.product(shape)
if arrlen > olen: raise ValueError("Insufficent sized vector to extract the desired parameters")
return obj[:arrlen].reshape(shape)
def _addDictPrefix(prefix,items,prefixedKeys=None):
"""Recursively adds the prefix prefix to the keys in items, and any sub-dicts."""
if not isinstance(items,dict): return items
newdict = dict()
for (k,v) in items.items():
if prefixedKeys is None or k in prefixedKeys:
if prefix+k in newdict:
raise RuntimeError("Adding dictionary prefix %s created a name clash with key %s"%(str(prefix),str(k)))
newdict[prefix+k] = _addDictPrefix(prefix,v,prefixedKeys)
else:
newdict[k] = _addDictPrefix(prefix,v,prefixedKeys)
return newdict
[docs]class Context:
"""Base class for all symbolic operations. Define variables, expressions, and user data here.
"""
def __init__(self):
self.variables = []
self.variableDict = dict()
self.userData = dict()
self.expressions = dict()
self.customFunctions = dict()
[docs] def addVar(self,name,type='V',size=None):
"""Creates a new Variable with the given name, type, and size.
Valid types include:
- V: vector (default)
- M: matrix
- N: numeric (generic float or integer)
- B: boolean
- I: integer
- A: generic array
- X: index
size is a hint for vector variables to help determine how many entries
the variable should contain. It can be None, in which case the size is
assumed unknown.
"""
if name in self.variableDict:
raise RuntimeError("Variable "+name+" already exists")
if name in self.userData:
raise RuntimeError("Variable "+name+" conflicts with a user data")
if name in self.expressions:
raise RuntimeError("Variable "+name+" conflicts with an expression")
v = Variable(name,Type(type,size),self)
self.variables.append(v)
self.variableDict[name] = v
return v
[docs] def addVars(self,prefix,N,type='N',size=None):
"""Creates N new Variables with the names prefix+str(i) for i in 1,...,N.
type is the same as in addVar, but by default it is numeric.
"""
res = []
for i in range(N):
res.append(self.addVar(prefix+str(i+1),type,size))
return res
[docs] def addUserData(self,name,value):
"""Adds a new item to user data. These are un-serializable objects that can
be referred to by functions, but must be restored after loading."""
if name in self.variableDict:
raise RuntimeError("User data "+name+" conflicts with a variable")
if name in self.expressions:
raise RuntimeError("User data "+name+" conflicts with an expression")
if name in self.customFunctions or name in _builtin_functions:
raise RuntimeError("User data "+name+" conflicts with a function")
self.userData[name] = value
return UserDataExpression(name)
[docs] def addExpr(self,name,expr):
"""Declares a named expression"""
if name in self.variableDict:
raise RuntimeError("Expression "+name+" conflicts with a variable")
if name in self.expressions:
raise RuntimeError("Expression "+name+" already exists")
if name in self.customFunctions or name in _builtin_functions:
raise RuntimeError("Expression "+name+" conflicts with a function")
assert isinstance(name,str)
assert isinstance(expr,Expression)
self.expressions[name] = expr
return expr
[docs] def declare(self,func,fname=None,fargs=None):
"""Declares a custom function. If fname is provided, this will be how
the function will be referenced. Otherwise, the Python name of the
function is used.
Args:
func (function or :class:`Expression`): the function / expression
object
fname (str, optional): the name of the new Function. If func is an
Expression this is mandatory. Otherwise it is taken from the
Python function name.
fargs (list of str, optional): the argument names. If func is an
:class:`Expression` this is mandatory.
To convert Expressions into Functions, the fargs list declares the
order in which variables in func will be bound to arguments. E.g.,
``Context.declare(2*expr("x")*expr("y"),"twoxy",["x","y"])``
produces a two argument function ``twoxy(x,y): 2*x*y``.
"""
if isinstance(func,Function):
if fname is None:
fname = func.name
import copy
newfunc = copy.copy(func)
newfunc.name = fname
self.customFunctions[fname] = newfunc
return newfunc
if isinstance(func,Expression):
assert fname != None,"When declaring an expression as a function, the name must be provided"
assert fargs != None,"When declaring an expression as a function, the argument order must be provided"
self.customFunctions[fname] = Function(fname,func,fargs)
return self.customFunctions[fname]
if callable(func):
if fname is None:
import inspect
for x in inspect.getmembers(func):
if x[0] == '__name__':
fname = x[1]
break
assert fname != None
self.customFunctions[fname] = Function(fname,func,fargs)
return self.customFunctions[fname]
else:
raise ValueError("func should be a symbolic Function, symbolic Expression, or Python function")
[docs] def include(self,context,prefix=None,modify=False):
"""Adds all items inside another context as a sub-context.
If prefix != None, ``self.[prefix]`` is set to ``context``, and all
names in that context are prepended by the given prefix and a '.'.
If ``modify=True``, then all the Variable and Function names inside
the given sub-context are modified to fit the `self` context. This is
useful if you are saving/loading expressions and using the convenience
form self.[prefix].[functionName] to instantiate Functions embedded
inside the sub-context.
.. note::
When using modify=True, be careful not to re-use the sub-context
inside multiple super-contexts.
"""
if prefix is not None:
assert not hasattr(self,prefix),"Can't include a new context, already have an element named "+prefix
setattr(self,prefix,context)
prefix = prefix + '.'
if prefix is None:
modify = False
for v in context.variables:
n = v.name if prefix is None else prefix+v.name
self.addVar(n,v.type)
if modify:
v.name = n
if modify:
context.variableDict = dict((v.name,v) for v in context.variables)
def emodify(e):
if isinstance(e,VariableExpression):
return VariableExpression(self.variableDict[prefix + e.var.name])
elif isinstance(e,UserDataExpression):
return UserDataExpression(prefix +e.var.name)
elif isinstance(e,OperatorExpression):
return OperatorExpression(e.functionInfo,[emodify(a) for a in e.args],e.op)
return e
for n,e in context.expressions.items():
n = n if prefix is None else prefix+n
if prefix is None:
self.addExpr(n,e)
else:
#need to modify any user data or variables
self.addExpr(n,emodify(e))
for n,d in context.userData.items():
n = n if prefix is None else prefix+n
self.addUserData(n,d)
newfunctions = []
for n,d in context.customFunctions.items():
if modify:
fnew = d
else:
fnew = copy.copy(d)
fnew.simplifierDict = _addDictPrefix(prefix,fnew.simplifierDict,context.customFunctions)
fnew.presimplifierDict = _addDictPrefix(prefix,fnew.presimplifierDict,context.customFunctions)
newfunctions.append(fnew)
for f in newfunctions:
f.name = prefix + f.name
self.declare(f)
if modify:
context.variableDict = dict((v.name,v) for v in context.variables)
context.userData = dict((prefix+n,d) for n,d in context.userData.items())
context.expressions = dict((prefix+n,e) for n,e in context.expressions.items())
context.customFunctions = dict((d.name,d) for d in context.customFunctions.values())
[docs] def copy(self,copyFunctions=False):
"""Returns a shallow copy of this context."""
res = Context()
for v in self.variables:
res.variables.append(Variable(v.name,v.type,res))
for v in res.variables:
res.variableDict[v.name] = v
res.userData = self.userData.copy()
#deep-copy expressions so they refer to expressions and variables in the new context
def rebind(expr,res):
if isinstance(expr,OperatorExpression):
newargs = [rebind(a,res) for a in expr.args]
return OperatorExpression(expr.functionInfo,newargs,expr.op)
elif isinstance(expr,VariableExpression):
return VariableExpression(res.variableDict[expr.var.name])
else:
return expr
for n,e in self.expressions.items():
res.expressions[n] = rebind(e,res)
if copyFunctions:
res.customFunctions = self.customFunctions.copy()
else:
res.customFunctions = self.customFunctions
return res
[docs] def bind(self,vars,values):
"""Assigns multiple variables to values"""
for v,val in zip(vars,values):
if isinstance(v,str):
self.variableDict[v].value = val
else:
assert isinstance(v,Variable)
v.value = val
[docs] def bindFunction(self,function,remapping=None):
"""Produces an Expression that evalutes the function, where its
arguments are bound to variables / user data in the current
environment. The argument names should map to similarly named
variables or user data.
If remapping is provided, then it maps function arguments to variables
or values, either as
- a dictionary mapping a function argument arg to remapping[arg].
- a list or tuple mapping the i'th function argument to remapping[i].
"""
if isinstance(function,str):
function = self.customFunctions[function]
assert isinstance(function,Function)
args = []
for aindex,arg in enumerate(function.argNames):
if remapping is not None:
if isinstance(remapping,dict):
if arg in remapping:
var = remapping[arg]
if isinstance(var,str) and var in self.variableDict:
#Do we want to map it to the corresponding variable? Or just keep it as a userData reference?
pass
args.append(var)
continue
else:
#its a list
var = remapping[aindex]
if isinstance(var,str) and var in self.variableDict:
#Do we want to map it to the corresponding variable? Or just keep it as a userData reference?
pass
args.append(var)
continue
if arg in self.variableDict:
args.append(self.variableDict[arg])
elif arg in self.userData:
args.append(arg)
else:
raise ValueError("Function %s argument %s does not exist in the current context"%(function.name,arg))
return function(*args)
[docs] def listVars(self,doprint=True,indent=0):
if doprint:
for v in self.variables:
print(' '*indent + "%s (type %s)"%(v.name,str(v.type)))
return [v.name for v in self.variables]
[docs] def listExprs(self,doprint=True,indent=0):
if doprint:
for n,f in self.expressions.items():
print(' '*indent + n + '=' + str(f))
return list(self.expressions.keys())
[docs] def listFunctions(self,doprint=True,indent=0,builtins=False):
global _builtin_functions
if doprint:
for n,f in _builtin_functions.items():
argstr = '...' if f.argNames is None else ','.join(f.argNames)
print(' '*indent + n + '(' + argstr + ')')
for n,f in self.customFunctions.items():
argstr = '...' if f.argNames is None else ','.join(f.argNames)
print(' '*indent + n + '(' + argstr + ')')
if builtins:
return list(self.customFunctions.keys()) + list(_builtin_functions.keys())
return list(self.customFunctions.keys())
[docs] def expr(self,name):
"""Retrieves a named expression"""
try:
return self.expressions[name]
except KeyError:
pass
raise KeyError("Invalid expression name "+name)
[docs] def function(self,name):
"""Retrieves a named function.."""
global _builtin_functions
try:
return _builtin_functions[name]
except KeyError:
pass
try:
return self.customFunctions[name]
except KeyError:
pass
raise KeyError("Invalid function key "+name)
[docs] def get(self,name,*args):
"""Retrieves a named reference to a userData or variable.
Can be called as ``get(name)``, in which case a KeyError is raised if
the key does not exist, or ``get(name,defaultValue)`` in which case the
default value is returned if the key does not exist.
"""
try:
return self.userData[name]
except KeyError:
pass
try:
return self.variableDict[name]
except KeyError:
pass
if len(args) == 0:
raise KeyError("Invalid variable / userData key "+name)
return args[0]
[docs] def renameVar(self,item,newname):
"""Renames a variable"""
if isinstance(item,str):
item = self.variableDict[item]
if item.name == newname: return
assert newname not in self.variableDict,"Renamed variable name "+newname+" already exists"
del self.variableDict[item.name]
self.variableDict[newname] = item
item.name = newname
[docs] def renameUserData(self,itemname,newname):
"""Renames a userData"""
assert itemname in self.userData,"Userdata "+itemname+" does not exist"
if itemname == newname: return
assert newname not in self.userData,"Renamed userdata name "+newname+" already exists"
self.userData[newname] = self.userData[itemname]
del self.userData[itemname]
[docs] def makePyFunction(self,expr,varorder=None):
"""Converts an Expression or Function to a Python function ``f(x)``
that takes ``x`` as a list of scalar values, maps those to Variable
values, and returns the result of evaluating the expression / function.
Args:
expr (:class:`Function` or :class:`Expression`): the function or
expression to evaluate
varorder (list, optional): If given, the list of Variables that
should appear in the flattened argument list ``x``.
If this isn't provided, then the Variables in ``expr`` are
ordered by the order in which were added to this ``Context``.
Returns:
tuple: A pair ``(f,varorder)``, where:
* ``f(x)`` is a 1-argument Python function equivalent to ``expr``
but where ``x`` is a list of variable values.
* ``varorder`` gives the order of variables that should be
sent in ``x``
"""
if isinstance(expr,Function):
if varorder is None:
varnames = expr.argNames
varorder = [self.variableDict[v] for v in varnames]
iorder = list(range(len(varorder)))
else:
varorder = [(self.variableDict[v] if isinstance(v,str) else v) for v in varorder]
iorder = [expr.argNames.index(v.name) for v in varorder]
varnames = [v.name for v in varorder]
def f(*args):
eargs = expr.argNames[:]
for i,a in zip(iorder,args):
eargs[i] = a
#print("Arguments",eargs)
return expr(*eargs).evalf(self)
return f,varorder
assert isinstance(expr,Expression)
res = expr.eval(self)
if isinstance(res,Expression):
rvars = res.vars(self,bound=False)
if varorder is None:
allorder = dict()
for i,v in enumerate(self.variables):
allorder[v.name] = i
varorder = sorted(rvars,key=lambda v:allorder[v.name])
else:
#convert strings to Variables
varorder = [(self.variableDict[v] if isinstance(v,str) else v)for v in varorder]
vnames = set(v.name for v in varorder)
for v in rvars:
if v.name not in vnames:
warnings.warn("Error while creating Python function corresponding to",res)
raise ValueError("Unbound variable "+v.name+" not in given variable order "+",".join([var.name for var in varorder]))
def f(*args):
#print("Evaluating with order",[str(v) for v in varorder],args)
for (v,val) in zip(varorder,args):
v.bind(val)
res = expr.evalf(self)
for (v,val) in zip(varorder,args):
v.unbind()
return res
return (f,varorder)
else:
if varorder is None:
varorder = []
return ((lambda *args: res),varorder)
[docs] def makeFlatFunction(self,expr,varorder=None,defaultsize=1):
"""Given an expression expr, return ``(f,varorder)``,
where ``f`` is an equivalent 1-argument Python function that takes a
list of numbers or numpy array. The order of variables that should
be provided to ``f`` in this tuple is returned in ``varorder``.
If vector variables are not given size hints, then they are assumed
to have ``defaultsize``.
See also :meth:`makePyFunction`.
"""
f,varorder = self.makePyFunction(expr,varorder)
indices = self.getFlatVarRanges(varorder,defaultsize)
def fv(x):
arglist = []
for i,v in enumerate(varorder):
if v.type.char in 'AM':
arglist.append(x[indices[i]:indices[i+1]].reshape(v.size))
elif v.type.char == 'V':
arglist.append(x[indices[i]:indices[i+1]])
else:
arglist.append(x[indices[i]])
#print("makeFlatFunction lambda function: Extracted arguments",arglist,"from",x)
#print("Evaluating expression",expr)
return f(*arglist)
return fv,varorder
[docs] def makeFlatFunctionDeriv(self,expr,varorder=None,defaultsize=1):
"""Given a differentiable expression expr, return ``(df,varorder)``,
where df is an 1-argument Python function that that takes a list of
numbers or numpy array and outputs the derivative (Jacobian matrix)
of expression ``expr``.
The order of variables that should be provided to ``df`` in this tuple
is returned in ``varorder``.
If ``expr`` is not differentiable, then ``df=None`` is returned.
If vector variables are not given size hints, then they are assumed to
have ``defaultsize``.
See also :meth:`makeFlatFunction`.
"""
if varorder is None:
f,varorder = self.makePyFunction(expr)
dvs = []
for v in varorder:
dv = expr.deriv(v)
if dv is None:
warnings.warn("makeFlatFunctionDeriv: Derivative with respect to {} is undefined, returning None".format(v.name))
return None,varorder
if not isinstance(dv,Expression):
if _is_exactly(dv,0) and not v.type.is_scalar():
#get the jacobian dimensions right
assert v.type.char == 'V',"Can't make flat function of non-numeric or non-vector types: "+str(v.type)
dv = zero(count.optimized(v))
else:
dv = ConstantExpression(dv)
dvs.append(dv)
if len(dvs) > 1:
rt = expr.returnType()
if rt is None or rt.char is None:
warnings.warn("makeFlatFunctionDeriv: no return type inferred, assuming vector")
rt = Type('V')
if rt.is_scalar():
dflist,varorder = self.makePyFunction(flatten(*dvs),varorder)
df = lambda *args: np.array(dflist(*args))
elif rt.char == 'V':
df,varorder = self.makePyFunction(column_stack(*dvs),varorder)
else:
warnings.warn("Expression {} has return type {}".format(expr,rt))
raise NotImplementedError("Can't make flat function of non-numeric or non-vector types: "+str(rt))
else:
df,varorder = self.makePyFunction(dvs[0],varorder)
indices = self.getFlatVarRanges(varorder,defaultsize)
def dfv(x):
arglist = []
for i,v in enumerate(varorder):
if v.type.char in 'AM':
arglist.append(x[indices[i]:indices[i+1]].reshape(v.size))
elif v.type.char == 'V':
arglist.append(x[indices[i]:indices[i+1]])
else:
arglist.append(x[indices[i]])
#print("Extracted arguments",arglist)
return df(*arglist)
return dfv,varorder
[docs] def getFlatVarRanges(self,varorder,defaultsize=1):
"""If varorder is a variable order returned by makeFlatFunction, returns the list of index ranges
r=[k0,...,kn] where variable index i corresponds with x[r[i]:r[i+1]]. The x vector must have size
r[-1]."""
varorder = [(self.variableDict[v] if isinstance(v,str) else v) for v in varorder]
indices = [0]
for v in varorder:
assert v.type.char in 'AMV' or v.type.is_scalar(),"Can't handle list or user data types in makeFlatFunction"
for v in varorder:
if v.type.char in 'AM':
assert v.type.size is not None,"Matrix and array variables need to have their size specified"
indices.append(indices[-1]+v.type.len())
elif v.type.char == 'V':
size = v.type.size if v.type.size != None else defaultsize
indices.append(indices[-1]+size)
else:
indices.append(indices[-1]+1)
return indices
[docs] def getFlatVector(self,varorder):
"""If varorder is a variable order returned by makeFlatFunction, returns the x vector corresponding
to the current variable values."""
assert not any(v.value is None for v in varorder),"All variables' values must be set"
return np.hstack([v.value.flatten() if isinstance(v,np.ndarray) else v for v in varorder])
[docs]class Function(object):
"""A symbolic function. Contains optional specifications of argument and
return types, as well as derivatives.
Args:
name (str):
func (Python function or :class:`Expression`):
argNames (list of str, optional):
returnType (str or function, optional): sets ``self.returnType`` or
``self.returnTypeFunc``.
If ``func`` is a Python function, the argument list will be derived from
the declaration of the Python function. Otherwise, it must be an
:class:`Expression`, and ``argNames`` needs to be provided. The
expression must be closed, meaning that all unspecified variables in the
expression must be named in ``argNames``.
Examples:
Basic instantiation::
#standard function method
def f(x,y):
...
symfunc = Function("f",f)
a = context.addVar("a")
b = context.addVar("b")
c = context.addVar("c")
print(symfunc(a,b)) # prints f($a,$b)
Creating a Function from an Expression (sort of like a lambda function)::
expr = a + 2*b
symfunc2 = Function("f2", expr, ["a","b"]) #when symfunc2 is called, its first argument will be bound to a, and its second will be bound to b
print(symfunc2(c,4)) #prints f2(c,4)
print(symfunc2(c,4).eval({'c':5})) #prints 13, because c + 2*4 = 5 + 2*4 = 13
If ``func`` is an Expression, then you can automatically set derivatives
if all the sub-expressions have derivatives. To do so, use the
:meth:`autoSetJacobians` method.
Defining derivatives:
There are three ways to define derivatives: :meth:`setDeriv`,
:meth:`setJacobian`, and :meth:`autoSetJacobians` (available if
``func`` is an expression. In :meth:`setDeriv` you provide
Jacobian-vector products. In :meth:`setJacobian` you provide
Jacobian matrices (or tensors).
Attributes:
name (str): name of function used in printing and IO.
description (str, optional): text description of function.
func (Python function or :class:`Expression`): the function to be
evaluated.
argNames (list of strs, optional): names of arguments.
argTypes (list of Type, optional): list giving each argument's
:class:`Type`.
argDescriptions (list of strs, optional): strings describing each
argument
returnType (Type, optional): return :class:`Type`
returnTypeFunc (function, optional): a function that takes argument
types and produces a more specific return :class:`Type` than
``returnType``
returnTypeDescription (str, optional): description of the return type
deriv (optional): Can be either:
1. a list of Jacobian-vector products with respect to a derivative
of each argument: ``[df1(*args,darg1),...,dfn(*args,dargn)]``.
Here, ``dargi`` is an argument of the same shape as
``args[i]``, giving the derivative :math:`\\frac{dargs[i]}{dx}`
w.r.t. some scalar parameter x. The function ``dfi`` gives
:math:`\\frac{df}{dargs[i]}\\cdot\\frac{dargs[i]}{dx}` where
:math:`\\frac{df}{dargs[i]}` is the partial
of ``self`` with respect to the i'th argument.
2. a function ``df(args,dargs)`` giving the total derivative of f
given directional derivatives of each of the argmuments.
Here, ``dargs`` is a list of derivatives ``[darg1,...,dargn]``.
The result should be equal to
:math:`\\frac{df}{dx} = \\frac{df}{dargs[0]}\\cdot \\frac{dargs[0]}{dx} + ... + \\frac{df}{dargs[n-1]}\\cdot\\frac{dargs[n-1]}{dx}`.
The return value should have the same shape as
``self.returnType()``
colstackderiv (optional): same as ``deriv``, except that each function
accepts stacked argument derivatives. This can be more efficient
than filling in ``deriv``, since taking derivatives w.r.t. k
variables x1,...,xk can be done by setting up ``dargi`` as a matrix
``[dflatten(argi)/dx1 | ... | dflatten(argi)/dxk]``, i.e.,. the
column stacking of each of the flattened argument derivatives.
`dargi` has shape ``(count(argi),k)``.
rowstackderiv (optional): same as ``colstackderiv``, except that row-
wise stacked argument derivatives are accepted. In other words,
``dargi`` is a matrix
:math:`[\\frac{d\\text{flatten(args[i])}}{dx1}, ... , \\frac{d\\text{flatten(args[i])}}{dxk}]`
which has shape ``(k,count(argi))``. This
is more efficient than ``colstackderiv`` but corresponds less
directly to standard mathematical notation.
jacobian (list of functions, optional): list of Jacobian functions
with respect to each argument. Has the form
``[Jf1(*args),...,Jfn(*args)]`` where `Jfi` returns a matrix of
shape ``(count(self),count(args[i]))``. The total derivative with
respect to some variable ``x`` is
:math:`\\frac{df}{dx} =` ``reshape(dot(Jf1,flatten(darg1)) + dot(Jfn,flatten(dargn)),shape(self))``
presimplifier (function, optional):
simplifier (function, optional):
presimplifierDict (dict, optional): a nested dict, mapping argument
signatures to simplification functions. See :meth:`addSimplifier`
for more details.
simplifierDict (dict, optional): a nested dict, mapping argument
signatures to simplification functions. See :meth:`addSimplifier`
for more details.
properties (dict): possible properties of this function, such as
associativity, etc.
printers (dict of functions): a dict containing code generation
methods. Each entry, if present, is a function
``printer(expr,argstrs)`` that returns a string. Common key values
are:
* "str": for printing to console
* "parse": for parseCompatible=True exprToStr
"""
def __init__(self,name,func,argNames=None,returnType=None):
self.name = name
self.description = None
self.func = func
self.argNames = argNames
self.argTypes = None
self.argDescriptions = None
self.returnType = None
self.returnTypeFunc = None
self.returnTypeDescription = None
self.setReturnType(returnType)
if callable(func) and argNames is None:
import inspect
(argNames,varargs,keywords,defaults) = inspect.getargspec(func)
if varargs != None or keywords != None:
#variable number of arguments
warnings.warn("symbolic.py: Note that Function {} is declared with variable arguments, pass '...' as argNames to suppress this warning".format(name))
pass
else:
self.argNames = argNames
if self.argNames == '...':
self.argNames = None
if isinstance(func,Expression):
if returnType is None:
self.returnType = func.returnType()
if self.argNames is None:
raise ValueError("Cannot use Expression as a function with variable arguments")
self.exprArgRefs = []
for arg in self.argNames:
uexpr = func.find(UserDataExpression(arg))
if uexpr is None:
vexpr = func.find(VariableExpression(Variable(arg,'V')))
if vexpr is None:
raise ValueError("Expression-based function %s does not contain specified argument %s"%(str(func),arg))
else:
self.exprArgRefs.append(vexpr)
assert vexpr.var.name == arg
else:
self.exprArgRefs.append(uexpr)
assert uexpr.name == arg
self.argTypes = [e.returnType() if not isinstance(e,UserDataExpression) else Type(None) for e in self.exprArgRefs]
self.deriv = None
if isinstance(self.returnType,Type) and self.returnType.char is not None and self.returnType.char in "BI":
#no derivatives for boolean or integer functions
self.deriv = 0
self.colstackderiv = None
self.rowstackderiv = None
self.jacobian = None
self.presimplifier = None
self.simplifier = None
self.presimplifierDict = dict()
self.simplifierDict = dict()
self.properties = dict()
self.printers = dict()
def __getattribute__(self,name):
if name=='__doc__':
return self.info()
else:
return object.__getattribute__(self, name)
def __call__(self,*args):
if self.argNames is not None:
if len(args) != len(self.argNames):
raise ValueError("Invalid number of arguments passed to "+self.name)
if TYPE_CHECKING and self.argTypes is not None:
#print("Calling",self.name,"(",','.join(str(a) for a in args),")")
for i,(a,t) in enumerate(zip(args,self.argTypes)):
#print(" Type checking argument",a,"against type",t)
#print(" ... argument has type",type_of(a))
#raw_input()
if t is None: continue
a = expr(a)
if not t.match(type_of(a)):
raise ValueError("Invalid argument %d (%s) passed to %s: type %s doesn't match %s"%(i,str(simplify(a)),self.name,type_of(a),t))
#rint " match."
#raw_input()
return self._call(*args)
def _call(self,*args):
"""Internally used version of __call__ -- does not perform type checking"""
if callable(self.func):
return OperatorExpression(self,args)
else:
assert isinstance(self.func,Expression),"What..."+self.func.__class__.__name__
assert self.argNames is not None
assert len(self.exprArgRefs) == len(args)
return OperatorExpression(self,args,lambda *argvals:subs(self.func,self.exprArgRefs,argvals).eval())
[docs] def optimized(self,*args):
"""Similar to ``self(*args).simplify(depth=1)``, but optimized to
reduce complexity earlier. In particular, this directly applies the
operation if all arguments are constants, and it will try to apply
the simplifier immediately."""
if self.presimplifier != None:
res = self.presimplifier(*args)
if res is not None:
if isinstance(res,Expression):
return res.simplify(None,depth=1)
else:
return res
anyVariable = False
newargs = []
if self.argNames is not None:
assert len(args) == len(self.argNames),"Invalid number of arguments passed to "+self.name
for i,arg in enumerate(args):
ac = to_const(arg,shallow=True)
if ac is None:
anyVariable = True
newargs.append(arg)
else:
newargs.append(ac)
if not anyVariable:
#just call the function
if callable(self.func):
try:
return ConstantExpression(self.func(*newargs))
except Exception as e:
#print("Error evaluating expression",self.name,"arguments",newargs)
raise
else:
#return self.func.eval(context=dict(self.argNames,newargs))
return subs(self.func,self.exprArgRefs,newargs).eval()
else:
res = self.__call__(*newargs)
res2 = res._simplify(depth=1,constant_expansion=False)
if res2 is not None:
return res2
return res
[docs] def checkArg(self,arg):
"""Verifies that a named or indexed argument is valid, and normalizes it. Returns an (index,name) tuple"""
if isinstance(arg,str):
if self.argNames is None:
raise ValueError("Can't get an argument by name for a variable-argument function")
return self.argNames.index(arg),arg
if self.argNames is None:
if arg < 0:
raise ValueError("Negative argument index specified")
return arg,'arg_'+str(arg)
if arg < 0 or arg >= len(self.argNames):
raise ValueError("Invalid argument specified")
return arg,self.argNames[arg]
[docs] def setDeriv(self,arg,dfunc,asExpr=False,stackable=False):
"""Declares a (partial) derivative of the function with respect to
argument arg. The function ``dfunc`` is a function that takes
arguments ``(arg1,...,argn,dx)``, where ``dx`` is the
derivative of ``arg`` with respect to some variable x, and
returns :math:`df/darg(arg1,...,argn) * darg/dx`.
For vector-valued arguments and functions, the * is a matrix-vector
product, and ``dfunc`` is required to produce what's commonly known as
the Jacobian-vector product.
For matrix-valued arguments or functions, the * is a tensor product.
Args:
arg (int or str): either an argument index or name.
dfunc: either:
1. an Expression of variables arg1,...,argn,darg,
2. a ``Function`` of n+1 variables mapping to ``arg1,...,argn,dx``,
3. a Python function of n+1 variables ``(arg1,...,argn,dx)``
that either returns a value (setting ``asExpr=False``) or
an ``Expression`` of the derivative (if ``asExpr=True``).
4. None, to indicate that the derivative is not defined,
5. 0, to indicate that the derivative is identically 0.
In the normal case (``stackable=False``), the shape of ``dx``
is the same as ``arg``, and the result has the same shape as
self.
asExpr (bool, optional): if ``dfunc`` is a Python function, this
flag says that it will return an ``Expression``. Using an
``Expression`` allows taking multiple derivatives.
stackable (bool or str, optional): states whether ``dfunc`` can
accept stacked derivative arguments. If ``True`` or ``'col'``,
then it can accept ``dvar`` as a column-stacked array of
derivatives. In other words, ``dvar[...,0]`` is the derivative
w.r.t. x0, ``dvar[...,1]`` is the derivative w.r.t. x1, etc.
E.g., for vector arguments, dvar can be thought of as a matrix
with k=``dvar.shape[-1]`` derivatives in its *columns*.
In this case, ``dfunc`` should return an array of shape
shape(self) x k (i.e., f's derivatives are stacked in columns.)
``stackable`` can also be ``'row'`` in which case ``dfunc``
accepts ``dvar`` as a row-stacked array of derivatives. I.e,
``dvar[0,...]`` is the derivative w.r.t. x0, ``dvar[1,...]`` is
the derivative w.r.t., x1, etc. In this case, ``dfunc`` should
return a list of derivatives.
If the argument is not an array or its type cannot be
determined (see ``setArgType``) then ``dvar`` will be passed as
a ``count(arg)`` x k matrix (for column-stacking) or a
k x ``count(arg)`` matrix (for row stacking).
If the function does not return an array or the return type
cannot be determined (see ``setReturnType``), the shape of the
returned matrix must be ``count(self)`` x k (for column-
stacking) or k x ``count(self)`` (for row-stacking).
"""
aindex,arg = self.checkArg(arg)
if self.deriv is None:
self.deriv = [None]*len(self.argNames)
if isinstance(dfunc,Expression):
self.deriv[aindex] = lambda *args:subs(dfunc,self.exprArgRefs+['d'+arg],args)
elif isinstance(dfunc,Function):
self.deriv[aindex] = dfunc
elif dfunc is None or _is_exactly(dfunc,0):
self.deriv[aindex] = dfunc
elif asExpr:
assert callable(dfunc)
#print("Setting derivative function with name",self.name,"argument",arg,"as expression")
self.deriv[aindex] = dfunc
else:
assert callable(dfunc)
#print("Declaring new derivative function with name",self.name + "_deriv_" + arg)
temp_function_info = Function(self.name + "_deriv_" + arg,dfunc,self.argNames + [VAR_DERIV_PREFIX+arg])
if not hasattr(self,'deriv_funcs'):
self.deriv_funcs = [None]*len(self.argNames)
self.deriv_funcs[aindex] = temp_function_info
self.deriv[aindex] = temp_function_info #lambda *args:OperatorExpression(temp_function_info,args)
if stackable == True or stackable == 'col':
if self.colstackderiv is None:
self.colstackderiv = [None]*len(self.argNames)
self.colstackderiv[aindex] = self.deriv[aindex]
elif stackable == 'row':
if self.rowstackderiv is None:
self.rowstackderiv = [None]*len(self.argNames)
self.rowstackderiv[aindex] = self.deriv[aindex]
[docs] def setJacobian(self,arg,dfunc,asExpr=False):
"""Declares a (partial) derivative of the function with respect to
argument ``arg``.
This only makes sense when the arguments are all arrays or scalars; no
complex types are supported.
Args:
arg (int or str): same as ``setDeriv``.
dfunc: similar to ``setDeriv``, but a function that takes n
arguments (arg1,...,argn) and returns the matrix
:math:`df/darg(arg1,...,argn)`.
The return value of ``dfunc`` must have shape shape(self) + shape(arg),
so that:
- If self and arg are scalars, the Jacobian is also a scalar
- If self and arg are a scalar and a vector (or vice versa) the
Jacobian is a vector.
- If self and arg contain a scalar and an N-D array, the Jacobian
is an N-D array.
- If self and arg are both vectors, the Jacobian is a matrix.
"""
aindex,arg = self.checkArg(arg)
if TYPE_CHECKING:
self_irregular = isinstance(self.returnType,Type) and self.returnType.char not in SCALAR_TYPES+ARRAY_TYPES
var_irregular = self.argTypes is not None and isinstance(self.argTypes[aindex],Type) and self.argTypes[aindex].char not in SCALAR_TYPES+ARRAY_TYPES
if self_irregular:
warnings.warn("symbolic.autoSetJacobians: setting Jacobian of a non-array function. Will flatten.")
if var_irregular:
warnings.warn("symbolic.autoSetJacobians: setting Jacobian of a non-array argument. Will flatten.")
if self.deriv is None:
self.deriv = [None]*len(self.argNames)
if self.jacobian is None:
self.jacobian = [None]*len(self.argNames)
if isinstance(dfunc,Expression):
self.jacobian[aindex] = lambda *args:subs(dfunc,self.exprArgRefs,args)
elif isinstance(dfunc,Function):
self.jacobian[aindex] = dfunc
elif dfunc is None or _is_exactly(dfunc,0):
self.jacobian[aindex] = dfunc
elif asExpr:
if not callable(dfunc):
raise TypeError("Unexpected type of dfunc %s"%(dfunc.__class__.__name__))
#print "Setting jacobian function with name",self.name,"argument",arg,"as expression"
self.jacobian[aindex] = dfunc
else:
if not callable(dfunc):
raise TypeError("Unexpected type of dfunc %s"%(dfunc.__class__.__name__))
#print "Declaring new derivative function with name",self.name + "_jacobian_" + arg
temp_function_info = Function(self.name + "_jacobian_" + arg,dfunc,self.argNames)
self.jacobian[aindex] = temp_function_info #lambda *args:OperatorExpression(temp_function_info,args)
[docs] def autoSetJacobians(self,args=None):
"""For ``Expression``-based functions, can automatically set the
Jacobians. If ``args`` is not None, only the arguments in ``args``
are set.
"""
if not isinstance(self.func,Expression):
raise ValueError("Can only auto-set jacobians for Expressions")
if args is None:
args = self.argNames
for arg in args:
darg = self.func.deriv(arg)
if _DEBUG_DERIVATIVES:
print("symbolic.autoSetJacobians: Derivative of",self.func,"w.r.t",arg,"is (type %s)"%(darg.__class__.__name__,),darg)
if darg is not None:
darg = expr(darg)
self.setJacobian(arg,darg,asExpr=True)
[docs] def setReturnType(self,type):
"""Sets a return type specifier.
Args:
type (:class:`Type`, str, or function): if a Type object, the return
type is set directly. If a character type specifier, it is cast
to Type. If it is a function, it is assumed to be a function
that that takes in argument Types and returns a :class:`Type`.
"""
if callable(type):
self.returnTypeFunc = type
else:
self.returnType = Type(type)
[docs] def setArgType(self,arg,type):
"""Sets an argument type specifier.
Args:
arg (int or str): an index or string naming an argument.
type (Type or str): a :class:`Type` object or character type
specifier for the specified argument.
"""
if self.argTypes is None:
self.argTypes = [None]*len(self.argNames)
index,name = self.checkArg(arg)
self.argTypes[index] = Type(type)
[docs] def getArgType(self,arg):
"""Retrieves an argument type.
Args:
arg (int or str): the index or string naming an argument.
"""
if self.argTypes is None: return None
index,name = self.checkArg(arg)
return self.argTypes[index]
[docs] def addSimplifier(self,signatures,func,pre=False):
"""For a signature tuple, sets the simplifier to ``func``.
Args:
signatures (list): a list of argument signatures. A signature can
be:
* the operation name for an OperatorExpression, passing the arg
directly to func(...)
* '_scalar': matches to a constant scalar, and passes that
constant to func(...)
* '_const': matches to constant, and passes that constant to
``func(...)``
* '_returnType': passes the argument's ``returnType()`` to
``func(...)``
* None: match all. The arg is passed directly to func(...)
func (callable): a callable that takes a list of function arguments,
possibly transformed by _const or _returnType. This returns a
simplified :class:`Expression` or ``None``.
pre (bool, optional): if True, the simplifier is called before
arguments are simplified.
If multiple matches are present then they are tested in order
1. operation name
2. _const
3. _returnType
4. None
"""
if self.argNames is not None:
if len(signatures) != len(self.argNames):
raise ValueError("Invalid signature length")
if len(signatures) == 0:
raise ValueError("Invalid signature tuple")
root = self.presimplifierDict if pre else self.simplifierDict
for s in signatures[:-1]:
if s is not None and not isinstance(s,str):
raise ValueError("Signatures must be strings or None")
root = root.setdefault(s,dict())
root[signatures[-1]] = func
[docs] def simplify(self,args,pre=False):
"""Performs simplification of ``OperatorExpression(self,args)``, either
with the simplifier function or the simplifierDict."""
if pre:
simplifier = self.presimplifier
simplifierDict = self.presimplifierDict
else:
simplifier = self.simplifier
simplifierDict = self.simplifierDict
for a in args:
assert isinstance(a,Expression),"Uh... args to simplify need to be Expressions, got an arg %s of type %s"%(str(a),a.__class__.__name__)
if simplifier is not None:
res = simplifier(*args)
if res is not None:
if SHAPE_CHECKING and isinstance(res,Expression):
try:
shres = res.returnType().shape()
if to_const(shres) is not None:
shself = self.__call__(*args).returnType().shape()
if to_const(shself) is None or to_const(shres) != to_const(shself):
warnings.warn("symbolic.simplify: simplified version of {} doesn't have same size: {} -> {}".format(self.functionInfo.name,shself,shres))
warnings.warn("Args: {}".format(args))
warnings.warn("Simplified {}".format(res))
except Exception:
pass
return res
if len(simplifierDict) > 0:
#print "Trying to match",[str(a) for a in args],"to simplifier dict"
root = simplifierDict
stack = []
passedArgs = args[:]
for i,a in enumerate(args):
#print "Keys:",root.keys()
#print "Arg",i,":",a
if isinstance(a,OperatorExpression) and a.functionInfo.name in root:
if _DEBUG_SIMPLIFY:
print("Function",self.name,"simplifier list matches arg",a.functionInfo.name)
stack.append(a.functionInfo.name)
root = root[a.functionInfo.name]
continue
if '_scalar' in root:
ac = to_scalar(a)
if ac is not None:
if _DEBUG_SIMPLIFY:
print("Function",self.name,"simplifier list matches _scalar for arg",a)
stack.append('_scalar')
passedArgs[i] = ac
root = root['_scalar']
continue
if '_const' in root:
ac = to_const(a)
if ac is not None:
if _DEBUG_SIMPLIFY:
print("Function",self.name,"simplifier list matches _const for arg",a)
stack.append('_const')
passedArgs[i] = ac
root = root['_const']
continue
if '_returnType' in root:
if _DEBUG_SIMPLIFY:
print("Function",self.name,"simplifier list matches arg",a,"returnType",a.returnType())
stack.append('_returnType')
passedArgs[i] = a.returnType()
root = root['_returnType']
continue
if None in root:
#print " Matches None"
if _DEBUG_SIMPLIFY:
print("Function",self.name,"simplifier list matches all for arg",a)
stack.append(None)
root = root[None]
continue
#no match
if _DEBUG_SIMPLIFY:
if isinstance(a,OperatorExpression):
print("Function",self.name,"simplifier list",list(root.keys()),"fails to match",a)
return None
if callable(root):
res = root(*passedArgs)
if _DEBUG_SIMPLIFY:
print("Function",self.name,"matches pattern",stack,", simplified to",res)
if SHAPE_CHECKING and isinstance(res,Expression):
try:
shres = res.returnType().shape()
if to_const(shres) is not None:
shself = self.__call__(*args).returnType().shape()
if to_const(shself) is None or to_const(shres) != to_const(shself):
warnings.warn("symbolic.simplify: simplified version of {} doesn't have same size: {} -> {}".format(self.name,shself,shres))
warnings.warn("Args: {}".format(passedArgs))
warnings.warn("Simplified {}".format(res))
except Exception:
pass
return res
return None
return None
[docs] def info(self):
"""Returns an text string describing the Function, similar to a docstring"""
argstr = '...' if self.argNames is None else ','.join(self.argNames)
signature = '%s(%s)'%(self.name,argstr)
if isinstance(self.func,Expression):
signature = signature + '\n Defined as '+str(self.func)
argHelp = None
if self.argTypes is not None or self.argDescriptions is not None:
if self.argNames is None:
argHelp = [str(self.argDescriptions)]
else:
argHelp= []
if self.argTypes is not None:
assert len(self.argNames) == len(self.argTypes),"invalid arg specification for function %s, %d args != %d types"%(str(self.func),len(self.argNames),len(self.argTypes))
for i,name in enumerate(self.argNames):
type = None if self.argTypes is None else self.argTypes[i]
desc = None if self.argDescriptions is None else self.argDescriptions[i]
if desc is None:
if type is not None:
desc = type.info()
else:
desc = 'unknown'
argHelp.append('- %s: %s'%(name,desc))
returnHelp = None
if self.returnTypeDescription is not None:
returnHelp = "- " + self.returnTypeDescription
elif self.returnType is not None:
returnHelp = "- " + self.returnType.info()
elif self.returnTypeFunc is not None:
if self.returnTypeFunc == _propagate_returnType:
returnHelp = "- same as arguments"
elif self.returnTypeFunc == _promote_returnType:
returnHelp = "- same as arguments"
elif self.returnTypeFunc == _returnType1:
returnHelp = "- same as argument 1"
elif self.returnTypeFunc == _returnType2:
returnHelp = "- same as argument 2"
elif self.returnTypeFunc == _returnType3:
returnHelp = "- same as argument 3"
else:
returnHelp = "- dynamic"
derivHelp = None
if self.deriv is not None or self.jacobian is not None:
derivHelp = []
if _is_exactly(self.deriv,0):
derivHelp.append('- derivative is 0 everywhere')
elif callable(self.deriv):
deval = None
if self.argNames is not None:
argTypes = [Type(None)]*len(self.argNames) if self.argTypes is None else self.argTypes
vars = [expr(Variable(a,t)) for a,t in zip(self.argNames,argTypes)]
dvars = [expr(Variable('d'+a,t)) for a,t in zip(self.argNames,argTypes)]
try:
deval = self.deriv(vars,dvars)
except Exception:
pass
if deval is not None:
derivHelp.append('- derivative is '+str(deval))
else:
derivHelp.append('- derivative is a total derivative function')
elif callable(self.jacobian):
derivHelp.append('- jacobian is a total derivative function')
elif self.argNames is not None:
argTypes = [Type(None)]*len(self.argNames) if self.argTypes is None else self.argTypes
vars = [expr(Variable(a,t)) for a,t in zip(self.argNames,argTypes)]
for i,a in enumerate(self.argNames):
if self.deriv is not None and self.deriv[i] is not None:
if _is_exactly(self.deriv[i],0):
derivHelp.append('- %s: derivative is 0'%(a,))
elif isinstance(self.deriv[i],Function):
derivHelp.append('- %s: available as Python df/da * da/dx function'%(a,))
else:
try:
deval = self.deriv[i](*(vars+[expr(Variable('d'+a,argTypes[i]))]))
if is_op(deval,'subs'):
deval = deval.args[0]
derivHelp.append('- %s: available as df/da * da/dx function %s'%(a,str(deval)))
except Exception as e:
derivHelp.append('- %s: available as df/da * da/dx function, Exception %s'%(a,str(e)))
elif self.jacobian is not None and self.jacobian[i] is not None:
if _is_exactly(self.jacobian[i],0):
derivHelp.append('- %s: jacobian is 0'%(a,))
elif isinstance(self.deriv[i],Function):
derivHelp.append('- %s: available as Python jacobian df/da'%(a,))
else:
try:
deval = self.jacobian[i](*vars)
if is_op(deval,'subs'):
deval = deval.args[0]
derivHelp.append('- %s: available as jacobian df/da %s'%(a,str(deval)))
except Exception as e:
derivHelp.append('- %s: available as jacobian df/da, Exception %s'%(a,str(e)))
items = [signature]
if self.description != None:
items += ['',self.description,'']
if argHelp is not None and len(argHelp) > 0:
items += ['','Parameters','---------']+argHelp
if returnHelp is not None:
items += ['','Return type','-----------',returnHelp]
if derivHelp is not None and len(derivHelp) > 0:
items += ['','Derivatives','-----------']+derivHelp
return '\n'.join(items)
[docs]class Variable:
def __init__(self,name,type,ctx=None):
self.name = name
self.type = Type(type)
self.ctx = None if ctx is None else weakref.proxy(ctx)
self.value = None
[docs] def isAssigned(self):
return (self.value is not None)
[docs] def bind(self,value):
self.value = value
[docs] def unbind(self):
self.value = None
[docs] def context(self):
return self.ctx
def __str__(self):
if self.value is None:
return self.name
else:
return self.name+'='+str(self.value)
def __len__(self):
if self.type.size == 0:
raise TypeError("Numeric Variable has no len()")
if self.type.size is None:
if self.value is not None:
return len(self.value)
raise TypeError("Variable-sized vector Variable has no len()")
return self.type.size
def __not__(self):
return not_(self)
def __and__(self,rhs):
return and_(self,rhs)
def __or__(self,rhs):
return or_(self,rhs)
def __eq__(self,rhs):
return eq(self,rhs)
def __ne__(self,rhs):
return ne(self,rhs)
def __le__(self,rhs):
return le(self,rhs)
def __ge__(self,rhs):
return ge(self,rhs)
def __neg__(self):
return neg(self)
def __add__(self,rhs):
return add(self,rhs)
def __sub__(self,rhs):
return sub(self,rhs)
def __mul__(self,rhs):
return mul(self,rhs)
def __div__(self,rhs):
return div(self,rhs)
def __truediv__(self,rhs):
return div(self,rhs)
def __pow__(self,rhs):
return pow_(self,rhs)
def __radd__(self,lhs):
return add(lhs,self)
def __rsub__(self,lhs):
return sub(lhs,self)
def __rmul__(self,lhs):
return mul(lhs,self)
def __rdiv__(self,lhs):
return div(lhs,self)
def __rtruediv__(self,lhs):
return div(lhs,self)
def __rpow__(self,lhs):
return pow_(lhs,self)
def __getitem__(self,index):
return getitem(self,index)
def __abs__(self):
return abs_(self)
def __setitem__(self,index,val):
return setitem(self,index,val)
def __getattr__(self,attr):
if attr == 'T':
return transpose(self)
elif attr == 'shape':
return shape(self)
else:
raise AttributeError("'Variable' object has no attribute '%s'" % attr)
#def __getslice__(self,*args):
# return _builtin_functions['getslice'](self,index)
class _TraverseException(Exception):
def __init__(self,e,node,task,exc_info):
self.e = e
self.exc_info = exc_info
self.path = [node]
self.parents = []
n = node
while n._parent is not None and not isinstance(n._parent,str):
self.parents.append(n._parent)
n = n._parent[0]()
self.path.append(n)
self.parents = self.parents[::-1]
self.path = self.path[::-1]
self.task = task
def __str__(self):
print(":", end=' ')
if isinstance(self.e,_TraverseException):
#lazy traversal
print("Expression traversal: Recursive exception during %s"%(str(self.task),))
elist = [self.e]
while isinstance(elist[-1].e,_TraverseException):
elist.append(elist[-1].e)
self.print_call_stack()
for i,e in enumerate(elist):
print(" ","@",e.task,"@:")
e.print_call_stack((i+1)*2)
print("Lowest level exception:", end=' ')
return elist[-1].e.__class__.__name__+": "+str(elist[-1].e)
else:
print("Expression traversal: Exception '%s' during %s"%(self.e,str(self.task)))
self.print_call_stack()
return ""
def __repr__(self):
return self.__str__()
def reraise(self):
raise self.e.with_traceback(self.exc_info)
def print_call_stack(self,indent=0):
self.path[0]._parent = self.task
for n,p in zip(self.path[1:],self.parents):
n._parent = p
self.path[-1]._print_call_stack(indent)
[docs]class Expression(object):
def __init__(self):
self._parent = None
self._id = None
self._cache = dict()
self._children = None
[docs] def returnType(self):
"""Returns the output type of this expression."""
raise NotImplementedError()
[docs] def depth(self,cache=True):
"""Returns the depth of the deepest leaf of this expression tree (i.e.,
its height). If cache = True, depths are cached as attribute 'depth' to
speed up subsequent depth() calls."""
try:
return self._cache['depth']
except KeyError:
pass
def _depth(node,childdepths):
if len(childdepths)==0: return (True,1)
return (True,max(childdepths)+1)
if not cache:
return self._traverse(post=_depth,cache=False)
else:
return self._traverse(post=_depth,cache=True,clearcache=False,cacheas='depth')
[docs] def isConstant(self):
"""Returns true if this expression is identically a constant. If so,
then the expression can be safely replaced with its evalf() value
without any possible change in meaning.
"""
raise NotImplementedError()
[docs] def returnConstant(self,context=None):
"""Returns true if the evaluation of the expression in the given
context results in a constant. If so, then :meth:`evalf` can safely be
applied without error."""
raise NotImplementedError()
[docs] def eval(self,context=None):
"""Evaluates the expression while substituting all constant values.
The result can be a constant or an Expression if the result is not
constant.
Args:
context (Context or dict, optional): a map of variable names to
constant values.
"""
return self._eval(context)
[docs] def evalf(self,context=None):
"""Evaluates the expression while substituting all constant values,
enforcing that the result is a constant. If the result is not
constant, a ValueError is returned.
"""
r = self.eval(context)
if not is_const(r):
from . import symbolic_io
raise ValueError("Expression.evalf: result of %s is not a constant, result is %s"%(str(symbolic_io.exprToJson(self)),str(r)))
if isinstance(r,ConstantExpression):
return r.value
if isinstance(r,Expression):
from . import symbolic_io
raise ValueError("Expression.evalf: result of %s is not a constant, result is %s"%(str(symbolic_io.exprToJson(self)),str(r)))
return r
def _eval(self,context):
"""Internally used eval"""
raise NotImplementedError()
def _evalf(self,context=None):
"""Internally used evalf. May be a tiny bit faster than :meth:`evalf`
but has less informative error messages.
"""
raise NotImplementedError()
[docs] def deriv(self,var,context=None):
"""Returns a new expression where all bound variables and variables in
context are reduced to constant values. If all arguments are constant,
returns a constant.
Args:
var: either a Variable, a string naming a Variable, or a dictionary
mapping variable names to derivatives. For example, in the
latter case, if this is an expression f(x,y),and var is a
dictionary {"x":dx,"y":dy}, then the result is
``df/dx(x,y)dx + df/dy(x,y)dy``
context (Context or dict, optional): a map of variable names to
constant values.
"""
return self._deriv(var,context)
def _deriv(self,var,context,rows):
"""Internally used deriv"""
return 0
def __str__(self):
"""Returns a string representing this expression"""
from . import symbolic_io
exp = symbolic_io.exprToStr(self,parseCompatible=False)
assert isinstance(exp,str),"Invalid type returned by exprToStr: "+exp.__class__.__name__
return exp
def __repr__(self):
from . import symbolic_io
return symbolic_io.exprToStr(self,parseCompatible=True)
[docs] def vars(self,context=None,bound=False):
"""Returns a list of all free Variables under this expression.
If bound is True, any Variables currently bound to values are
returned, otherwise they are not."""
return []
[docs] def match(self,val):
"""Returns true if self is equivalent to the given expression. Note: no
rearrangement is attempted, so x+y does not match y+x."""
if isinstance(val,Wildcard): return True
return False
[docs] def find(self,val):
"""Returns self if self contains the given expression as a sub-expression, or None otherwise."""
if self.match(val): return self
return None
[docs] def replace(self,val,valReplace,error_no_match=False):
"""Returns a copy of this expression where any sub-expression matching val gets replaced
by valReplace. If val doesn't exist as a subexpression, self is returned."""
if self.match(val):
return valReplace
if error_no_match:
raise ValueError("Unable to find term %s in expression %s"%(str(val),str(self)))
return self
[docs] def simplify(self,context=None,depth=None):
"""Returns a simplified version of self, or self if it cannot be
simplified.
Args:
context (dict or :class:`Context`): any Variable / UserData whose
name appears in here will be substituted with its value.
depth (int, optional): None for full-depth simplification, or the
max depth to explore.
Returns:
Expression:
"""
res = self._simplify(context,depth)
if res is None: return self
return res
def _simplify(self,context=None,depth=None):
"""Returns a new expression, or None if this cannot be simplified.
- context: a dict or Context for constant substitution
- depth: None for full-depth simplification, or the max depth to explore
"""
return None
def _traverse(self,pre=None,post=None,cache=True,clearcache=True,cacheas=None):
"""Generic traversal function.
Args:
pre (function, optional): a function f(expr) of an Expression that
is called before traversing children. It returns a triple
``(descend,cont,value)`` which controls the recursion:
* If descend is True, then this proceeds to traverse children
and call post. value is ignored.
* If descend is False, value is the return value of the
traversal.
* If cont is True, then the parent of expr continues traversing
children.
* If cont is False, then the parent of expr continues
traversing children.
For example, a find function would return (False,False,*) if an
item is found, and (True,True,*) otherwise.
post (function, optional): a two-argument function
``f(expr,childvals)`` called after traversing children.
``childvals`` are the return values of traversing children.
It returns a pair ``(cont,value)``, where value is the return
value of the traversal under ``expr``, and ``cont`` which
controls the recursion:
* If cont is True, the parent of expr continues traversing
children.
* If cont is False, the parent of expr stops traversing
children and value is used the return value of the parent.
If post is evaluated, the cont value here overrides the cont
value of pre.
cache (bool, optional): if True, uses the caching functionality.
clearcache (bool, optional): if True, deletes the cache after
traversal. Default uses cache and clears it.
cacheas (str, optional): if not None, caches the return value as an
attribute. For this to work, ``cache`` must be True.
Returns:
result: The value returned by pre or post.
"""
if cacheas is None:
cacheas = 0
if cache and clearcache:
if cacheas in self._cache:
print("Hmm... cache[%s] was not cleared correctly before traverse is called? Or perhaps reentrant traversal?"%(str(cacheas),))
print("Myself: %s, cached value %s"%(str(self),str(self._cache[cacheas])))
import traceback
traceback.print_stack()
cache = False
def _traverse_recurse_cache(node,pre,post):
try:
res = node._cache[cacheas]
if _DEBUG_TRAVERSE and cacheas in _DEBUG_TRAVERSE_ITEMS:
print("Found cached value cache[%s]=%s for node %s"%(cacheas,str(res),str(self)))
return (True,res)
except KeyError:
#print "Did not find cached value cache[%s]"%(cacheas,)
pass
value = None
cont = True
if pre is not None:
try:
(descend,cont,value) = pre(node)
except Exception as e:
import sys
return (False,_TraverseException(e,node,'pre-traverse '+cacheas,sys.exc_info()[2]))
if not descend:
if _DEBUG_TRAVERSE and cacheas in _DEBUG_TRAVERSE_ITEMS:
print("Not descending under %s, caching as %s and returning value"%(str(node),cacheas,value))
node._cache[cacheas] = value
return (cont,value)
cvals = []
if node._children is not None:
for i,c in enumerate(node._children):
try:
res = c._cache[cacheas]
if _DEBUG_TRAVERSE and cacheas in _DEBUG_TRAVERSE_ITEMS:
print("Found cached child %d cache[%s]=%s for node %s"%(i,cacheas,str(res),str(node)))
cvals.append(res)
except KeyError:
c._parent = (weakref.ref(node),i)
(ccont,value) = _traverse_recurse_cache(c,pre,post)
c._parent = None
if isinstance(value,_TraverseException):
return False,value
if not ccont:
if _DEBUG_TRAVERSE and cacheas in _DEBUG_TRAVERSE_ITEMS:
print("Not continuing under %s, caching as %s and returning value"%(str(node),cacheas,value))
node._cache[cacheas] = value
return (False,value)
cvals.append(value)
if post is not None:
try:
(cont,value) = post(node,cvals)
except Exception as e:
import sys
return (False,_TraverseException(e,node,cacheas,sys.exc_info()[2]))
if _DEBUG_TRAVERSE and cacheas in _DEBUG_TRAVERSE_ITEMS:
print("Completed traversal under %s, caching as %s and returning value %s"%(str(node),cacheas,value))
node._cache[cacheas] = value
return (cont,value)
def _traverse_recurse_nocache(node,pre,post):
value = None
cont = True
if pre is not None:
try:
(descend,cont,value) = pre(node)
except Exception as e:
import sys
return (False,_TraverseException(e,node,'pre-traverse '+cacheas,sys.exc_info()[2]))
if not descend:
if _DEBUG_TRAVERSE and cacheas in _DEBUG_TRAVERSE_ITEMS:
print("Not descending under",str(node),", returning value",value)
return (cont,value)
cvals = []
if node._children is not None:
for i,c in enumerate(node._children):
c._parent = (weakref.ref(node),i)
(cont,value) = _traverse_recurse_nocache(c,pre,post)
c._parent = None
if isinstance(value,_TraverseException):
return False,value
if not cont:
if _DEBUG_TRAVERSE and cacheas in _DEBUG_TRAVERSE_ITEMS:
print("Not continuing under",str(node),", returning value",value)
return (False,value)
cvals.append(value)
if post is not None:
try:
(cont,value) = post(node,cvals)
except Exception as e:
import sys
return (False,_TraverseException(e,node,cacheas,sys.exc_info()[2]))
if _DEBUG_TRAVERSE and cacheas in _DEBUG_TRAVERSE_ITEMS:
print("Completed traversal under",str(node),", returning value",value)
return (cont,value)
self._parent = str(cacheas)
if cache:
cont,value = _traverse_recurse_cache(self,pre,post)
if clearcache:
self._clearCache(cacheas)
else:
cont,value = _traverse_recurse_nocache(self,pre,post)
self._parent = None
if isinstance(value,_TraverseException):
raise value.with_traceback(value.exc_info)
return value
def _clearCache(self,key,deep=False):
incache = key in self._cache
if incache or deep:
if incache: del self._cache[key]
if self._children is not None:
for a in self._children:
a._clearCache(key,deep)
if _DEBUG_CACHE:
def checkcleared(n,key):
if key in n._cache:
raise RuntimeError("clearCache failed to actually clear cache[%s] on %s, value %s, in context %s"%(str(key),str(n),str(n._cache),str(self)))
if n._children is not None:
for a in n._children:
checkcleared(a,key)
if not deep:
checkcleared(self,key)
def _print_call_stack(self,indent=0):
"""Prints the traversal stack. Used for debugging purposes."""
from . import symbolic_io
if indent > 0:
print(""*(indent-1), end=' ')
sstr = '##'+str(self)+'##'
if self._parent is None:
print(sstr)
return
if isinstance(self._parent,str):
print(self._parent,"(",sstr,")")
return
def call_stack_str(n,cstr,cindex):
astr = []
for i,a in enumerate(n._children):
if i == cindex:
astr.append(cstr)
else:
astr.append('...')
nstr = symbolic_io._prettyPrintExpr(n,astr,parseCompatible=False)
if n._parent is None or isinstance(n._parent,str):
return nstr
ref = n._parent[0]()
if ref is not None:
return call_stack_str(ref,nstr,n._parent[1])
return nstr
ref = self._parent[0]()
if ref is not None:
print(call_stack_str(ref,sstr,self._parent[1]))
def _signature(self):
"""Returns a tuple that can be used for comparisons and hashing"""
return ()
def __bool__(self):
print("__bool__ was called on",self)
raise ValueError("Can't test the truth value of an Expression.... did you mean to use if_(...)?")
def __not__(self):
return not_(self)
def __and__(self,rhs):
return and_(self,rhs)
def __or__(self,rhs):
return or_(self,rhs)
def __eq__(self,rhs):
return eq(self,rhs)
def __ne__(self,rhs):
return ne(self,rhs)
def __le__(self,rhs):
return le(self,rhs)
def __ge__(self,rhs):
return ge(self,rhs)
def __neg__(self):
return neg(self)
def __add__(self,rhs):
return add(self,rhs)
def __sub__(self,rhs):
return sub(self,rhs)
def __mul__(self,rhs):
return mul(self,rhs)
def __div__(self,rhs):
return div(self,rhs)
def __truediv__(self,rhs):
return div(self,rhs)
def __pow__(self,rhs):
return pow_(self,rhs)
def __radd__(self,lhs):
return add(lhs,self)
def __rsub__(self,lhs):
return sub(lhs,self)
def __rmul__(self,lhs):
return mul(lhs,self)
def __rdiv__(self,lhs):
return div(lhs,self)
def __rtruediv__(self,lhs):
return div(lhs,self)
def __rpow__(self,lhs):
return pow_(lhs,self)
def __getitem__(self,index):
return getitem(self,index)
def __setitem__(self,index,val):
return setitem(self,index,val)
def __getattr__(self,attr):
if attr == 'T':
return transpose(self)
elif attr == 'shape':
return shape(self)
else:
raise AttributeError("'Expression' object has no attribute '%s'" % attr)
def __hash__(self):
return hash(self._signature())
[docs]class Wildcard(Expression):
"""Used for matching / finding"""
def __init__(self,name="*"):
self.name = name
[docs] def match(self,val):
return True
def _eval(self,context):
raise ValueError("Can't evaluate Wildcard expressions")
def _evalf(self,context):
raise ValueError("Can't evaluate Wildcard expressions")
def _deriv(self,var,context,rows):
raise ValueError("Can't take derivative of Wildcard expressions")
[docs]class ConstantExpression(Expression):
def __init__(self,value):
if value is None:
raise ValueError("Can't initialize a ConstantExpression with None")
Expression.__init__(self)
if isinstance(value,ConstantExpression):
self.value = value.value
elif isinstance(value,Expression):
raise ValueError("Can't initialize a ConstantExpression with a non-constant Expression")
else:
self.value = value
def checkType(v):
if hasattr(v,'__iter__') and not isinstance(v,str):
for x in v:
checkType(x)
assert not isinstance(v,ConstantExpression)
checkType(self.value)
[docs] def returnType(self):
return type_of(self.value)
[docs] def isConstant(self):
return True
[docs] def returnConstant(self,context=None):
return True
def _eval(self,context):
return self.value
def _evalf(self,context):
return self.value
[docs] def match(self,val):
if not isinstance(val,Expression): val = expr(val)
if isinstance(val,Wildcard): return True
return isinstance(val,ConstantExpression) and np.all(self.value == val.value)
def __len__(self):
return len(self.value)
def _signature(self):
if isinstance(self.value,_PY_PRIMITIVE_TYPES):
return (self.value,)
elif isinstance(self.value,np.ndarray):
return tuple(self.value.flatten())
else:
return (id(self.value),)
def __bool__(self):
return bool(self.value)
def __not__(self):
return ConstantExpression(not self.value)
def _do_binary_op(self,rhs,func):
if isinstance(rhs,(Expression,Variable)):
if isinstance(rhs,ConstantExpression):
return ConstantExpression(func.func(self.value,rhs.value))
return func(self,rhs)
else:
return ConstantExpression(func.func(self.value,rhs))
def __eq__(self,rhs):
return self._do_binary_op(rhs,eq)
def __ne__(self,rhs):
return self._do_binary_op(rhs,ne)
def __le__(self,rhs):
return self._do_binary_op(rhs,le)
def __ge__(self,rhs):
return self._do_binary_op(rhs,ge)
def __neg__(self):
return ConstantExpression(-self.value)
def __getitem__(self,index):
return self._do_binary_op(index,getitem)
def __getattr__(self,attr):
if attr.startswith("_"):
raise AttributeError("ConstantExpression object has no attribute '%s'" % attr)
if hasattr(self.value,attr):
return getattr(self.value,attr)
raise AttributeError("ConstantExpression object has no attribute '%s'" % attr)
[docs]class UserDataExpression(Expression):
def __init__(self,name):
Expression.__init__(self)
self.name = name
[docs] def returnType(self):
return Type('U')
[docs] def isConstant(self):
return False
[docs] def returnConstant(self,context=None):
if context is None: return False
if isinstance(context,dict):
return self.name in context
else:
return self.name in context.userData
def _eval(self,context):
if context is None:
return self
if isinstance(context,dict):
return context.get(self.name,self)
else:
return context.userData.get(self.name,self)
def _evalf(self,context):
assert context is not None,"Can't evalf UserDataExpressions without a context"
if isinstance(context,dict):
return context[self.name]
else:
return context.userData[self.name]
def _deriv(self,var,context,rows):
print("Attempting derivative of user data expression",self.name,"w.r.t.",var)
if isinstance(var,dict):
return var.get(self.name,0)
else:
if self.name == var:
raise ValueError("Cannot take derivatives with respect to user-data variables")
return 0
[docs] def match(self,val):
if not isinstance(val,Expression): return isinstance(val,str) and val == self.name
if isinstance(val,Wildcard): return True
return isinstance(val,UserDataExpression) and self.name == val.name
[docs] def vars(self,context=None,bound=False):
if context is None or bound:
return [self]
elif isinstance(context,dict):
if self.name in context: return []
else:
if self.name in context.userData: return []
return [self]
def _simplify(self,context=None,depth=None):
if context is not None:
res = context.get(self.name,None)
if isinstance(res,_PY_CONST_TYPES) or isinstance(res,(list,tuple)):
return ConstantExpression(res)
return None
return None
def _signature(self):
return ('$'+self.name,)
def __getattr__(self,attr):
if attr.startswith("_"):
raise AttributeError("UserData object has no attribute '%s'" % attr)
return getattr_(self,attr)
[docs]class VariableExpression(Expression):
def __init__(self,var):
Expression.__init__(self)
self.var = var
[docs] def returnType(self):
return self.var.type
[docs] def isConstant(self):
return False
[docs] def returnConstant(self,context=None):
return (self.var.value is not None)
def _eval(self,context):
if self.var.value is not None: return self.var.value
if isinstance(context,dict):
return context.get(self.var.name,self)
return self
def _evalf(self,context):
if self.var.value is not None: return self.var.value
if isinstance(context,dict):
return context[self.var.name]
raise ValueError("Must provide variable binding or context")
def _deriv(self,var,context,rows):
if isinstance(var,Variable): return self.deriv(var.name)
elif isinstance(var,dict):
return var.get(self.var.name,0)
else:
if not isinstance(var,str):
raise ValueError("Can't take the derivative of an expression w.r.t. anything except a str, Variable, or dict")
if self.var.name == var:
size = self.var.type.size
if size is None:
assert self.var.type.char in SCALAR_TYPES,"Can only take derivative if the size of a variable is known"
return 1
else:
assert len(size) == 1,"Can't take derivative with respect to a non-vector Variable yet?"
return _eye(size[0])
return 0
[docs] def match(self,val):
if isinstance(val,Wildcard): return True
if isinstance(val,Variable): return self.var.name == val.name
return isinstance(val,VariableExpression) and self.var.name == val.var.name
[docs] def vars(self,context=None,bound=False):
if not bound and self.var.value is not None: return []
return [self.var]
def _simplify(self,context=None,depth=None):
if self.var.value is not None: return ConstantExpression(self.var.value)
if isinstance(context,dict): return ConstantExpression(context.get(self.var.name,None))
return None
def _signature(self):
return (self.var.name,)
def __len__(self):
return len(self.var)
[docs]class OperatorExpression(Expression):
"""A compound function of individual Expressions"""
def __init__(self,finfo,args,op=None):
Expression.__init__(self)
self.functionInfo = finfo
if op is not None:
self.op = op
else:
self.op = finfo.func
self.args = []
for a in args:
self.args.append(expr(a))
self._children = self.args
#def __del__(self):
# #remove references in parents list
# #print("OperatorExpression.__del__ called")
# for a in self.args:
# a._parents = [(p,index) for p,index in a._parents if p() is not None]
def _do(self,constargs):
"""Calculates function with the given resolved argument tuple"""
assert not any(isinstance(a,Expression) for a in constargs),"Trying to evaluate an Expression's base function with non-constant inputs?"
assert self.op is not None
try:
return self.op(*constargs)
except Exception as e:
#this will be handled during eval's exception handler
"""
print("Error while evaluating",self.functionInfo.name,"with arguments:")
if self.functionInfo.argNames:
for (name,a) in zip(self.functionInfo.argNames,constargs):
print(" ",name,"=",a)
else:
print(" ",','.join([str(a) for a in constargs]))
print("Error",e)
import traceback
print("Traceback")
traceback.print_exc()
try:
print("Call stack:",)
self._print_call_stack()
except Exception as e2:
print("Hmm... exception occured while calling print_call_stack()?")
print(e2)
import traceback
print("Traceback")
traceback.print_exc()
raise ValueError("Exception "+str(e)+" while evaluating function "+self.functionInfo.name+" with args "+",".join(str(v) for v in constargs))
"""
task = "Evaluating "+self.functionInfo.name+" with args "+",".join(str(v) for v in constargs)
import sys
raise _TraverseException(e,self,task,sys.exc_info()[2]).with_traceback(sys.exc_info()[2])
[docs] def returnType(self):
if 'returnType' in self._cache:
return self._cache['returnType']
if self.functionInfo.returnTypeFunc != None:
try:
type = self.functionInfo.returnTypeFunc(*self.args)
self._cache['returnType'] = type
return type
except Exception as e:
print("Exception while evaluating",self.functionInfo.name,"return type with args",','.join(str(v) for v in self.args))
print(" exception is:",e)
raise
if self.functionInfo.returnType:
self._cache['returnType'] = self.functionInfo.returnType
return self.functionInfo.returnType
self._cache['returnType'] = Type(None)
return Type(None)
[docs] def isConstant(self):
if 'constant' not in self._cache:
if len(self.args) == 0:
self._cache['constant'] = True
else:
self._cache['constant'] = all(a.isConstant() for a in self.args)
return self._cache['constant']
[docs] def returnConstant(self,context=None):
def _returnconstant(node,childargs):
if len(childargs)==0:
if isinstance(node,OperatorExpression): return (True,True)
return (True,node.returnConstant(context))
res = all(childargs)
if not res:
return (False,False)
return (True,True)
return self._traverse(post=_returnconstant,cache=False)
[docs] def eval(self,context=None):
"""Returns a new expression where all bound variables and variables in context
are reduced to constant values. If all arguments are constant, returns a
constant expression."""
if 'eval' in self._cache:
print("WARNING: Expression.eval already has cached value?")
print(" expression",self)
print(" Cached value",self._cache['eval'])
raise ValueError()
input()
res = self._eval(context)
self._clearCache('eval')
return res
def _eval(self,context):
def _preeval(node):
if isinstance(node,OperatorExpression) and hasattr(node.functionInfo,'custom_eval'):
return (False,True,node.functionInfo.custom_eval(*([context]+node.args)))
return (True,True,None)
def _posteval(node,aexprs):
if not isinstance(node,OperatorExpression):
return (True,node._eval(context))
if any(isinstance(a,Expression) for a in aexprs):
if any(a is not b for a,b in zip(aexprs,node.args)):
return (True,OperatorExpression(node.functionInfo,aexprs))
else:
return (True,node)
else:
return (True,node._do(aexprs))
return self._traverse(_preeval,_posteval,cacheas='eval',clearcache=False)
def _evalf(self,context):
def _preeval(node):
if isinstance(node,OperatorExpression) and hasattr(node.functionInfo,'custom_eval'):
res = node.functionInfo.custom_eval(*([context]+node.args))
if isinstance(res,ConstantExpression):
res = res.value
assert not isinstance(res,Expression),"Invalid evalf call on expression with custom eval"
return (False,True,res)
return (True,True,None)
def _posteval(node,aexprs):
if not isinstance(node,OperatorExpression):
return (True,node._evalf(context))
return (True,node._do(aexprs))
return self._traverse(_preeval,_posteval,cacheas='evalf')
[docs] def deriv(self,var,context=None):
"""Returns an expression for the derivative dself/dvar, where all
bound variables and variables in context are reduced to constant
values.
Normally, ``var`` is a ``Variable`` or ``str``.
The derivative/Jacobian size is determined as follows:
- If both var and self are scalar, the result is a scalar.
- If self is scalar and var is non-scalar, the result has the same
shape as var.
- If self is non-scalar and var is scalar, the result the same shape
as self.
- If both self and var are arrays, the result is an n-D array with
shape shape(self)+shape(var).
- If other self or var are compound (i.e., unstructured lists), the
result is an m x n matrix, where m = count(self) and n = count(var).
If no derivative can be determined, this returns None.
If the expression evaluates to a constant, the result will be either 0
or a constant expression.
Another option is for ``var`` to be a dict of variable names to their
derivatives.
"""
if not isinstance(var,dict):
if isinstance(var,Variable):
name = var.name
else:
name = var
vd = self.find(VariableExpression(Variable(var,None)))
if vd is not None:
var = vd.var
else:
ud = self.find(UserDataExpression(name))
if ud is not None:
warnings.warn("symbolic.deriv: taking the derivative w.r.t. user data {}, assuming a numeric value".format(var))
var = VariableExpression(Variable(var,'N'))
else:
raise ValueError("Can't take the derivative w.r.t. undefined variable "+var)
var_irregular = var.type.char not in SCALAR_TYPES + ARRAY_TYPES
try:
var_shape = var.type.shape()
except Exception:
var_shape = shape.optimized(var)
if not var_irregular:
var_scalar = False
if isinstance(var_shape,tuple):
#array type, fixed shape
var_cols = np.product(var_shape)
if var_shape == ():
var_cols = 0
var_dims = len(var_shape)
var_scalar = (var_cols == 0)
else:
#array type, variable shape
var_scalar = False
var_cols = count.optimized(var)
var_dims = dims.optimized(var)
assert is_const(var_dims),"Variable dimensions are not constant? "+str(var_dims)+' type '+str(var.type)
var_dims = to_const(var_dims)
#reshape arrays with 2 or more dims
var_deriv_reshape = (var_dims > 1)
else:
#irregular type
var_cols = count.optimized(var)
if is_const(var_cols):
var_cols = int(to_const(var_cols))
var_scalar = False
var_dims = dims.optimized(var)
#don't reshape hyper-types, keep them as count
var_deriv_reshape = False
#print "Number of columns of deriv variable",var,"is",cols
#if not is_const(cols):
# print "WARNING: variable number of columns",cols,"in derivative Variable type",var.type
if isinstance(var_cols,int) and var_cols == 0:
varderivs = {name:1}
assert var_deriv_reshape == False
elif isinstance(var_cols,int) and var_cols == 1:
varderivs = {name:eye.optimized(var_cols)}
assert var_deriv_reshape == False
else:
if var_deriv_reshape:
varderivs = {name:reshape._call(eye._call(var_cols),flatten.optimized(var_cols,var_shape))}
else:
varderivs = {name:eye._call(var_cols)}
else:
#var is a dict
varderivs = var
var_deriv_reshape = False
#because some functions' derivatives may call Expression.deriv again with the
#same dict, we need to determine the dict size dynamically
var_cols = 0
for (k,v) in varderivs.items():
if is_op(v,'eye'):
var_cols = v.args[0]
elif is_op(v,'reshape'):
var_cols = getitem.optimized(v.args[1],0)
elif is_const(v):
var_cols = 0
break
print("CALLING DERIV",self,"WITH ARG DERIV dict INFERRED COLUMNS",var_cols)
if to_const(var_cols) is not None:
var_cols = to_const(var_cols)
var_scalar = is_scalar(var_cols,0)
var_dims = 0
var_irregular = False
res = self._deriv(varderivs,context,var_cols)
self._clearCache('deriv',deep=True)
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Derivative of",self,"w.r.t.",list(varderivs.keys()),"is",res)
if res is None:
#error occurred, return None
return None
self_type = self.returnType()
self_irregular = self_type is None or self_type.char not in SCALAR_TYPES + ARRAY_TYPES
self_dims = dims.optimized(self) if self_type is None else self_type.dims()
if not is_const(self_dims):
if _DEBUG_DERIVATIVES:
warnings.warn("symbolic.deriv: can't determine dimensions of self, trying eval")
self_dims = dims.optimized(self.eval(context))
if _DEBUG_DERIVATIVES:
if not is_const(self_dims):
warnings.warn(" ... really cannot determine dimensions of self. Proceeding as though non-scalar")
self_scalar = is_scalar(self_dims,0)
if not var_scalar and not self_scalar:
if to_const(self_dims+1) == 2:
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: doing transpose of",res)
res = transpose(res)
else:
#use the proper transpose for higher dimensional derivatives
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: using tensor transpose for self dims",self_dims,"var dims",var_dims)
if self_dims > 2:
print("WEIRD TENSOR RESULT",self)
res = transpose2(res,flatten(1+range_(dims.optimized(res)-1),0))
else:
#print "symbolic.deriv: not doing transpose of",res
pass
"""
if self_irregular or var_irregular:
print("derivative result",res,"with irregular self?",self_irregular,"or var:",var_irregular)
print("Res shape",shape.optimized(res))
print("Self shape",shape.optimized(self))
print("Var shape",var_shape)
print("Self scalar:",self_scalar,"var scalar:",var_scalar)
"""
if self_scalar and (var_irregular or var_deriv_reshape):
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: reshaping result to",var_shape)
res = reshape(res,var_shape)
if not self_scalar and var_deriv_reshape:
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: reshaping result to",simplify(flatten(shape(res)[0:-1],var_shape)))
res = reshape(res,simplify(flatten(shape(res)[0:-1],var_shape)))
if isinstance(res,OperatorExpression):
rsimp = res._postsimplify(depth=None)
res._clearCache('simplified',deep=True)
if rsimp is not None:
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Post-simplified derivative from",res,"to",rsimp)
res = rsimp
else:
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Couldn't simplify")
if isinstance(res,OperatorExpression):
rsimp = res._constant_expansion(context,depth=None)
res._clearCache('simplified',deep=True)
if _DEBUG_DERIVATIVES:
if rsimp is not None:
print("symbolic.deriv: Constant expansion of derivative from",res,"to",rsimp)
if rsimp is not None: return rsimp
return res
return res
def _deriv(self,varderivs,context,rows=None):
#varderivs must contain a map from variable names var to derivatives dvar.
#
#If rows == 0, then each dvar is a derivative of var w.r.t. a scalar parameter x.
#
#If rows > 0, then each dvar is a row-stacked array of shape [rows] + shape(var)
#if var is array-like, or (rows,count(var)) otherwise.
assert len(varderivs) > 0,"Must provide at least one variable derivative"
assert rows is not None
var_scalar = is_scalar(rows,0)
if var_scalar:
assert isinstance(to_const(rows),int),"NUMBER OF ROWS ISNT INTEGER? %s %s"%(str(rows),rows.__class__.__name__)
def _jacobian_shape(node):
return shape.optimized(node)
else:
def _jacobian_shape(node):
nt = node.returnType()
if nt is not None and nt.char in SCALAR_TYPES+ARRAY_TYPES:
return flatten.optimized(rows,shape.optimized(node))
else:
return array.optimized(rows,count.optimized(node))
def _deriv_pre(node):
if isinstance(node,OperatorExpression) and hasattr(node.functionInfo,'custom_eval'):
if node.functionInfo.deriv is None:
return (False,True,None)
if _is_exactly(node.functionInfo.deriv,0):
return (False,True,None)
assert callable(node.functionInfo.deriv),"custom_eval functions needs to define a callable deriv function"
if var_scalar:
assert isinstance(to_const(rows),int),"NUMBER OF ROWS ISNT INTEGER? %s %s"%(str(rows),rows.__class__.__name__)
res = node.functionInfo.deriv(*([context]+node.args+[varderivs,rows]))
if _is_exactly(res,0):
try:
res = zero._call(_jacobian_shape(node))
except Exception:
#there's a problem getting the shape... let's hope that it's ok to return 0
pass
elif res is None:
pass
else:
pass
#print("In branch",node,"deriv",res)
#raise NotImplementedError("Haven't tested this branch... why do we need a transpose?")
#res = transpose.optimized(res)
return (False,True,res)
return (True,True,None)
def _deriv_post(node,cvals):
res = 0
if isinstance(node,OperatorExpression):
if _is_exactly(node.functionInfo.deriv,0):
res = 0
elif node.functionInfo.deriv is None:
if all(_is_exactly(v,0) or is_zero(v) for v in cvals):
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: No derivative for function",node.functionInfo.name,"but we happily had no variables to consider")
print(" Values of arguments")
for v in node.args:
print(" ",v)
print(" Derivatives of arguments")
for v in cvals:
print(" ",simplify(v))
res = 0
else:
warnings.warn("symbolic.deriv: No derivative for function {}".format(node.functionInfo.name))
return True,None
elif all(_is_exactly(v,0) or is_zero(v) for v in cvals):
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: zero derivatives of arguments for",node.functionInfo.name)
res = 0
else:
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Computing raw derivative of %s(%s)"%(node.functionInfo.name,','.join(str(a) for a in node.args)))
print(" w.r.t.",', '.join('d/d'+k+'='+str(v) for (k,v) in varderivs.items()))
if len(cvals) <= 1:
print(" with argument derivatives",','.join([str(v) for v in cvals]))
else:
print(" with argument derivatives")
for v in cvals:
print(" -",v)
#ACTUALLY DO THE DERIVATIVE
res = node._do_deriv(cvals,rows)
#if _DEBUG_DERIVATIVES:
# print(" Result is %s (type %s)"%(str(simplify(res)),res.__class__.__name__))
# print(" unsimplified",res)
#print "Result of derivative of",node.functionInfo.name,"w.r.t.",[(k,str(v)) for (k,v) in varderivs.iteritems()],"is",str(res)
elif isinstance(node,VariableExpression):
if node.var.name in varderivs:
res = varderivs[node.var.name]
elif isinstance(node,UserDataExpression):
if node.name in varderivs:
res = varderivs[node.name]
elif isinstance(node,ConstantExpression):
res = 0
if res is not None and not _is_exactly(res,0):
dconst = to_const(dims.optimized(node))
if dconst is None:
raise ValueError("Odd, dimensions of node "+str(node)+" don't have constant # of dimensions? dims="+str(dims.optimized(node)))
if not var_scalar and dconst > 1:
rshape = shape.optimized(res)
outshape = _jacobian_shape(node)
if not is_const(rshape) or not is_const(outshape):
if not rshape.match(outshape):
warnings.warn("symbolic.deriv: had to reshape result of derivative of {}\n had type {} need shape {}".format(node,res.returnType(),outshape))
elif not np.array_equal(to_const(rshape),to_const(outshape)):
warnings.warn("symbolic.deriv: had to reshape result of derivative of {}\n had type {} need shape {}".format(node,rshape,outshape))
res = reshape.optimized(res,outshape)
if _is_exactly(res,0) and not _is_exactly(_jacobian_shape(node),()):
#no derivative
try:
res = zero._call(_jacobian_shape(node))
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Converted zero jacobian of",node,"to size",_jacobian_shape(node))
except Exception as e:
#there's a problem getting the shape... let's hope that it's ok to return 0
if _DEBUG_DERIVATIVES:
print("Uh... couldn't get the jacobian shape of",node)
print("=== REASON ===")
print(e)
print("=== END REASON ===")
pass
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Derivative of",node,"with arg derivs",cvals,"is (unsimplified)",res)
return True,res
return self._traverse(pre=_deriv_pre,post=_deriv_post,cacheas='deriv',clearcache=False)
def _do_deriv(self,dargs,stackcount):
"""Returns the derivative expression of the output with respect to some
variable x, when dargs is a list [dargs[i]/dx for i = 0,...,nargs-1]
If a derivative is not available, returns None
If a derivative is 0 or a zero vector/matrix, should return 0
If stackcount>0, each entry dargs[i] is an array with shape [stackcount]+shape(args[i]))
if args[i] is array-like, or a matrix with shape (stackcount,count(args[i])) if args[i]
is complex.
If stackcount==0, the result of this operation has shape shape(self).
If stackcount > 0, the result is a matrix with shape [stackcount]+shape(self) if
self is array-like, or a matrix with shape (stackcount,count(self)) if self is complex.
"""
if self.functionInfo.deriv is None: return None
if _is_exactly(self.functionInfo.deriv,0): return 0
assert len(dargs) == len(self.args),"Invalid argument to deriv: not enough arguments"
#short-circuit termination: no derivative available for sub-argument
if any(v is None for v in dargs): return None
#check what kind of derivative / jacobian we should produce
cstackcount = to_const(stackcount)
do_stack = cstackcount is None or cstackcount > 0
self_dims = dims.optimized(self)
self_type = self.returnType()
if not is_const(self_dims):
if _DEBUG_DERIVATIVES:
print("Can't determine dimensionality of expression",self,"type",self.returnType())
print(" dims type",self_dims.__class__.__name__)
print(" Proceeding by assuming non-scalar result")
input("Press enter to continue...")
else:
self_dims = to_const(self_dims)
self_irregular = self_type is None or self_type.char not in SCALAR_TYPES + ARRAY_TYPES
if not self_irregular:
try:
self_shape = self_type.shape()
except Exception:
self_shape = shape.optimized(self)
if do_stack:
jacobian_shape = flatten.optimized(stackcount,self_shape)
jacobian_dims = 1 + self_dims
else:
jacobian_shape = self_shape
jacobian_dims = self_dims
else:
if do_stack:
self_shape = None
self_count = count.optimized(self)
jacobian_shape = array.optimized(stackcount,self_count)
jacobian_dims = 2
else:
self_shape = shape.optimized(self)
jacobian_shape = self_shape
jacobian_dims = self_dims
if do_stack:
#determine whether the arguments' are array-like or not. None indicates non-array like
arg_dims = []
for (a,da) in zip(self.args,dargs):
art = a.returnType()
if art is not None and art.char in SCALAR_TYPES + ARRAY_TYPES:
arg_dims.append(art.dims())
else:
arg_dims.append(None)
res = None
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Reshaping arguments to",self,"have dims",arg_dims)
if callable(self.functionInfo.rowstackderiv):
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Using row-stack derivative to",self.functionInfo.name)
if any(v is None for v in dargs): return None
#print("callback rowstackderiv: Shapes",[simplify(shape(da)) for da in dargs])
#print("Needs reshaping?",needs_reshaping)
res = self.functionInfo.rowstackderiv(self.args,dargs)
#print("rowstackderiv,",self.functionInfo.name,"result",res)
elif callable(self.functionInfo.colstackderiv):
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Using column-stack derivative to",self.functionInfo.name)
daresized = []
for (a,da,adims) in zip(self.args,dargs,arg_dims):
if da is 0:
daresized.append(0)
elif adims is None:
#can't determine dimensions: prepare for the worst case tensor
daresized.append(transpose2(da,flatten(1+range_(adims-1),0)))
elif adims == 0:
daresized.append(da)
elif adims == 1:
daresized.append(transpose.optimized(da))
else:
daresized.append(transpose2(da,list(range(1,adims))+[0]))
res = self.functionInfo.colstackderiv(self.args,daresized)
if res is not None:
if jacobian_dims <= 1:
pass
elif jacobian_dims == 2:
res = transpose.optimized(res)
else:
raise ValueError("Can't rearrange column-stacked jacobians for matrix functions yet")
elif callable(self.functionInfo.deriv):
assert is_const(stackcount),"Can't do functional derivatives yet with variable stack size"
row_derivs = []
#print("DERIV",self.functionInfo.name,"WITH STACK COUNT",stackcount)
#print([str(da) for da in dargs])
for i in range(stackcount):
daresized = []
for a,da,adims in zip(self.args,dargs,arg_dims):
if da is 0:
daresized.append(0)
elif is_op(da,'zero'):
if adims == 0:
daresized.append(0)
else:
daresized.append(zero(shape.optimized(a)))
elif adims is None:
#non-array-like, needs reshaping
daresized.append(reshape.optimized(getitem.optimized(da,i),shape.optimized(a)))
else:
daresized.append(getitem.optimized(da,i))
#print " Reshaped i'th derivatives",[str(da) for da in daresized]
di = self.functionInfo.deriv(self.args,daresized)
dishape = to_const(shape.optimized(di))
if dishape is not None and to_const(self_shape) is not None:
assert dishape == self_shape,"Result of derivative %s should have shape %s, instead has %s"%(self.functionInfo.name,str(self_shape),str(dishape))
if self_irregular:
di = flatten.optimized(di)
row_derivs.append(di)
res = array(*row_derivs)
else:
#individual deriv, rowstackderiv, colstackderiv, or jacobian functions
#order from most efficient to least: rowstackderiv, colstackderiv, jacobian (if stackcount > count(a), deriv
#print("INDIVIDUAL DERIVS",self.functionInfo.name,"WITH STACK COUNT",stackcount)
#print([str(simplify(da)) for da in dargs])
res = 0
for index,da in enumerate(dargs):
if _is_exactly(self.functionInfo.deriv[index],0): continue
if _is_exactly(da,0) or is_op(da,'zero'): continue
adims = arg_dims[index]
needs_reshaping_i = (arg_dims[index] is None)
if self.functionInfo.rowstackderiv is not None and self.functionInfo.rowstackderiv[index] is not None:
inc = self.functionInfo.rowstackderiv[index](*(self.args+[da]))
elif self.functionInfo.colstackderiv is not None and self.functionInfo.colstackderiv[index] is not None:
if adims is None:
daresized = transpose2(da,flatten(1+range_(adims-1),0))
elif adims == 0:
daresized = da
elif adims == 1:
daresized = transpose.optimized(da)
else:
daresized = transpose2(da,list(range(1,adims))+[0])
#print("Stack",simplify(da_stack),"has shape",simplify(shape(da_stack)))
inc = self.functionInfo.colstackderiv[index](*(self.args+[daresized]))
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Column stack form",index," dargs:",daresized)
print(" raw derivative",inc)
if inc is not None:
#convert from column to row form
if jacobian_dims <= 1:
pass
elif jacobian_dims == 2:
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: doing transpose to convert",inc,"into row form")
inc = transpose.optimized(inc)
else:
raise ValueError("Can't rearrange column-stacked jacobians for matrix functions yet")
else:
use_jacobian = self.functionInfo.jacobian is not None and self.functionInfo.jacobian[index] is not None
if self.functionInfo.deriv is not None and self.functionInfo.deriv[index] is not None:
if adims != 2 and adims != None:
#not flattened?
use_jacobian = False
nargs = to_const(count.optimized(self.args[index]))
if use_jacobian and nargs is not None and (cstackcount is not None and nargs > cstackcount):
#jacobian is less efficient
use_jacobian = False
if use_jacobian:
#print("symbolic.deriv: Using jacobian for",self.functionInfo.name,"arg",index)
J = self.functionInfo.jacobian[index](*self.args)
assert J is not None
Jconst = to_const(J)
if Jconst is not None:
da_cols = shape.optimized(da)[-1]
if is_const(da_cols):
assert Jconst.shape[-1] == to_const(da_cols),"Invalid jacobian size for %s argument %d"%(self.functionInfo.name,index)
inc = dot(da,transpose(J))
else:
#da is stacked, but deriv can't handle them
if is_const(stackcount):
"""
print("arg_shape",arg_shape,"needs reshaping?",needs_reshaping_i)
for i in range(stackcount):
print("Argument derivative",i,":",simplify(reshape.optimized(da[i],arg_shape)))
print("Orig shape",shape.optimized(da[i]))
print("Shape",shape.optimized(reshape.optimized(da[i],arg_shape)))
print("Shape",shape.optimized(reshape(da[i],arg_shape)))
"""
darows = [getitem.optimized(da,i) for i in range(stackcount)]
ashape = shape.optimized(self.args[index])
rows = [self.functionInfo.deriv[index](*(self.args+[reshape.optimized(darow,ashape) if needs_reshaping_i else darow])) for darow in darows]
inc = expr(rows)
else:
#Need an apply_ and concat function
#inc = array(*map_(apply_(self.functionInfo.deriv[index],flatten(self.args,[reshaper(da['i'])])),'i',range_(stackcount))
raise NotImplementedError("Dynamically sized dispatch to stacked derivatives of function %s?"%(self.functionInfo.name,))
if inc is None:
return None
if _is_exactly(res,0):
res = inc
else:
if to_const(shape.optimized(inc)) is not None:
if self_irregular:
assert to_const(shape.optimized(inc)) == (to_const(stackcount),to_const(self_count)),"Invalid (irregular)jacobian size %s, should be %s"%(shape.optimized(inc),(stackcount,self_count))
else:
assert to_const(shape.optimized(inc)) == to_const(jacobian_shape),"Invalid jacobian size %s, should be %s"%(shape.optimized(inc),jacobian_shape)
res = add(res,inc)
#res += inc
if not _is_exactly(res,0):
if res is None:
return res
if self_irregular:
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: I think I need to reshape the result of",self,"to",jacobian_shape)
res = reshape(res,jacobian_shape)
return res
#normal derivative, stackcount = 0
if callable(self.functionInfo.deriv):
return self.functionInfo.deriv(self.args,dargs)
res = 0
assert len(self.functionInfo.deriv) == len(self.args)
for index,da in enumerate(dargs):
if _is_exactly(self.functionInfo.deriv[index],0): continue
if _is_exactly(da,0) or is_op(da,'zero'): continue
if self.functionInfo.deriv[index] is None:
if self.functionInfo.jacobian is not None and self.functionInfo.jacobian[index] is not None:
J = self.functionInfo.jacobian[index](*self.args)
assert J is not None
Jconst = to_const(J)
if Jconst is not None:
da_cols = shape.optimized(da)[-1]
if is_const(da_cols):
assert Jconst.shape[-1] == to_const(da_cols),"Invalid jacobian size for %s argument %d"%(self.functionInfo.name,index)
inc = dot(J,da)
else:
warnings.warn("symbolic.deriv: No partial derivative for function %s argument %d (%s)"%(self.functionInfo.name,index+1,self.functionInfo.argNames[index]))
warnings.warn(" Derivative with respect to argument is {}".format(da))
return None
else:
arg = self.args[index]
# and not is_const(stackcount) or to_const(stackcount) > 0?
if not callable(self.functionInfo.deriv[index]):
#df/di = const (usually 0)
inc = self.functionInfo.deriv[index]
else:
#da is not stacked
#reshape da to the format of self.args[index]
inc = self.functionInfo.deriv[index](*(self.args+[da]))
if _DEBUG_DERIVATIVES:
print("symbolic.deriv: Partial derivative for function",self.functionInfo.name,"argument",index+1,str(simplify(inc)))
if inc is None:
return None
if _is_exactly(res,0):
res = inc
else:
res = add(res,inc)
#res += inc
if SHAPE_CHECKING and not _is_exactly(res,0):
myshape = to_const(shape.optimized(self))
resshape = to_const(shape.optimized(res))
if myshape is not None and resshape is not None:
if myshape != resshape:
warnings.warn("symbolic.deriv: derivative doesn't match my shape? derivative shape %s = shape(%s)"%(resshape,simplify(res)))
warnings.warn(" compared to self's shape %s = shape(%s)"%(myshape,simplify(self)))
return res
[docs] def vars(self,context=None,bound=False):
varset = set()
varlist = []
def _vars_pre(node):
if isinstance(node,OperatorExpression) and hasattr(node.functionInfo,"custom_eval"):
#has a bound inner expression -- don't count temporary variables inside
if node.functionInfo.name in ['subs','map','forall','forsome','summation']:
exprvars = node.args[0].vars(context,bound)
loopvars = node.args[1]
replvars = node.args[2].vars(context,bound)
if is_op(loopvars,'array'):
for v in loopvars.args:
assert isinstance(v,UserDataExpression)
exprvars = [x for x in exprvars if x.name != v.name]
else:
assert isinstance(loopvars,UserDataExpression)
exprvars = [x for x in exprvars if x.name != loopvars.name]
for v in exprvars + replvars:
if v.name not in varset:
varlist.append(v)
varset.add(v.name)
"""
#DEBUGGING
print("Expression free variables",[x.name for x in exprvars],"replaced variables",[x.name for x in replvars])
exprvars += replvars
evars = node.functionInfo.custom_eval(context,*node.args).vars()
print("Custom eval",node.functionInfo.custom_eval(context,*node.args))
print("Test free variables",[x.name for x in evars])
assert all(x.name in [y.name for y in evars] for x in exprvars)
assert all(x.name in [y.name for y in exprvars] for x in evars)
"""
return False,True,None
return True,True,None
def _vars(node,cvals):
if isinstance(node,OperatorExpression):
return True,None
vlist = node.vars(context,bound)
for v in vlist:
if v.name not in varset:
varlist.append(v)
varset.add(v.name)
return True,None
self._traverse(pre=_vars_pre,post=_vars,cacheas='vars')
return varlist
[docs] def match(self,expr):
if isinstance(expr,Wildcard): return True
if not isinstance(expr,OperatorExpression): return False
if self.functionInfo.name != expr.functionInfo.name: return False
if len(self.args) != len(expr.args): return False
if 'depth' in self._cache:
if self._cache['depth'] != expr._cache.get('depth',None): return False
for (a,b) in zip(self.args,expr.args):
if not a.match(b): return False
return True
[docs] def find(self,term):
if not isinstance(term,Expression):
term = expr(term)
#TODO: don't use depth caching when term contains a wildcard?
self.depth(cache=True)
term.depth(cache=True)
tdepth = term._cache['depth']
def _find_pre(node):
nd = node._cache['depth']
if nd > tdepth: return (True,True,None)
elif nd == tdepth:
if node.match(term):
return (False,False,node)
else:
return (True,True,None)
#break of search if the tree is too shallow to contain expr
return (False,True,None)
def _find_post(node,cmatches):
matched = [v for v in cmatches if v is not None]
if len(matched) > 0: return False,matched[0]
return True,None
res = self._traverse(pre=_find_pre,post=_find_post,cacheas='find')
return res
[docs] def replace(self,term,termReplace,error_no_match=True):
if not isinstance(term,Expression):
term = expr(term)
if not isinstance(termReplace,Expression):
termReplace = expr(termReplace)
self.depth(cache=True)
term.depth(cache=True)
tdepth = term._cache['depth']
matched = [False]
def _replace_post(node,cres):
nd = node._cache['depth']
if nd > tdepth:
assert isinstance(node,OperatorExpression)
return (True,OperatorExpression(node.functionInfo,cres,node.op))
elif nd == tdepth:
if node.match(term):
#can't just say matched = True since matched is out of scope
matched[0] = True
return (True,termReplace)
return (True,node)
return (True,node)
res = self._traverse(post=_replace_post,cacheas='replace')
if not matched[0] and error_no_match:
raise ValueError("Unable to find term %s in expression %s"%(str(term),str(self)))
return res
def _presimplify(self,depth):
if _is_exactly(depth,0): return None
newdepth = None if depth is None else depth-1
changed = False
argsChanged = False
expr = self
for a in self.args:
assert isinstance(a,Expression),"Uh... args need to be Expressions, got an arg %s of type %s"%(str(a),a.__class__.__name__)
res = self.functionInfo.simplify(self.args,pre=True)
if res is not None:
expr = res
changed = True
if not isinstance(expr,OperatorExpression): return expr
newargs = []
for i,a in enumerate(expr.args):
a._parent = (weakref.ref(expr),i)
ares = (a._presimplify(newdepth) if isinstance(a,OperatorExpression) else None)
a._parent = None
if ares is not None:
newargs.append(ares)
argsChanged = True
else:
newargs.append(a)
if argsChanged:
return OperatorExpression(expr.functionInfo,newargs,expr.op)
else:
return expr
return None
def _postsimplify(self,depth,mydepth=0):
if 'simplified' in self._cache:
if _DEBUG_SIMPLIFY:
print(" "*mydepth,"Postsimplify",self,"cached to",self._cache["simplified"])
return self._cache['simplified']
if _is_exactly(depth,0):
return None
newdepth = None if depth is None else depth-1
simplified = False
newargs = []
if _DEBUG_SIMPLIFY:
print(" "*mydepth,"Postsimplify",self.functionInfo.name,":",self)
if self.functionInfo.name == 'subs':
#DON'T simplify variable lists or values-as-list
newargs= [self.args[0],self.args[1],self.args[2]]
values = self.args[2]
if is_op(values,'array') or is_op(values,'list'):
newvalues = []
for i,v in enumerate(values.args):
if isinstance(v,OperatorExpression):
v._parent = (weakref.ref(values),i)
vsimp = v._postsimplify(newdepth,mydepth+1)
v._parent = None
else:
vsimp = None
if vsimp is None:
newvalues.append(v)
else:
if _DEBUG_SIMPLIFY:
print(" "*mydepth," Simplified subs value",v,"to",vsimp)
newvalues.append(vsimp)
simplified = True
newargs[-1] = array(*newvalues)
else:
for i,a in enumerate(self.args):
if isinstance(a,OperatorExpression):
a._parent = (weakref.ref(self),i)
asimp = a._postsimplify(newdepth,mydepth+1)
a._parent = None
else:
asimp = None
if asimp is None:
newargs.append(a)
else:
if _DEBUG_SIMPLIFY:
print(" "*mydepth," Simplified arg",a,"to",asimp)
newargs.append(asimp)
simplified = True
for a in newargs:
assert isinstance(a,Expression),"Uh... args need to be Expressions, got an arg %s of type %s"%(str(a),a.__class__.__name__)
#default simplifications
#1. inversion
if len(newargs) == 1 and 'inverse' in self.functionInfo.properties:
if isinstance(newargs[0],OperatorExpression) and newargs[0].functionInfo.name == self.functionInfo.properties['inverse'].name:
res = newargs[0].args[0]
self._cache['simplified'] = res
if _DEBUG_SIMPLIFY:
print(" "*mydepth,"Simplified operation",self.functionInfo.name,"in context",OperatorExpression(self.functionInfo,newargs),"via inverse rule")
if isinstance(res,ConstantExpression):
assert res.value is not None
return res
#2. foldable variable-argument operators
if self.functionInfo.properties.get('foldable',False):
associative = self.functionInfo.properties.get('associative',False)
assert associative,"Foldable functions need to also be associative"
newargs2 = []
for a in newargs:
if isinstance(a,OperatorExpression) and a.functionInfo is self.functionInfo:
newargs2 += a.args
else:
newargs2.append(a)
if len(newargs2) > len(newargs):
simplified = True
if _DEBUG_SIMPLIFY:
print(' '*mydepth,"Folded variable argument operator",self.functionInfo.name,"args from",[str(e) for e in newargs],"to",[str(e) for e in newargs2])
newargs = newargs2
#custom simplification
assert not any(a is None or (isinstance(a,ConstantExpression) and a.value is None )for a in newargs)
for a in newargs:
assert isinstance(a,Expression),"Uh... args need to be Expressions, got an arg %s of type %s"%(str(a),a.__class__.__name__)
try:
res = self.functionInfo.simplify(newargs)
if res is not None:
if _DEBUG_SIMPLIFY:
print(" "*mydepth,"Simplified operation",self.functionInfo.name,"in context",OperatorExpression(self.functionInfo,newargs),"via simplifier to",res)
#print "Args",[str(a) for a in newargs]
if isinstance(res,OperatorExpression):
res2 = res._postsimplify(depth,mydepth)
if res2 is not None:
#should we clear the cache? less efficient for future operations, but better to be on the safe side
res._clearCache('simplified',deep=True)
if _DEBUG_SIMPLIFY:
print("... re-simplified to",res2)
res = res2
if not isinstance(res,Expression):
res = expr(res)
self._cache['simplified'] = res
return res
else:
pass
#print "Unable to simplify",self
#print "Call stack",self._print_call_stack()
except Exception as e:
print(" "*mydepth," Call stack:", end=' ')
self._print_call_stack()
print(" "*mydepth,"Exception:",e)
import traceback
traceback.print_exc()
print(" "*mydepth,"Error post-simplifying function",OperatorExpression(self.functionInfo,newargs,self.op))
raise
return None
#5. If-collapsing: pull all equivalent conditions to front
#op(if(cond,arg1,arg1'),arg2) => if(cond,op(arg1,arg2),op(arg1',arg2)) if the inner op's can be simplified
#op(if(cond,arg1,arg1')) => if(cond,op(arg1),op(arg1'))
#or op(if(cond,arg1,arg1'),if(cond,arg2,arg2')) => if(cond,op(arg1,arg2),op(arg1',arg2'))
if len(newargs) > 0 and any(is_op(a,'if') for a in newargs):
ifcond = None
for a in newargs:
if is_op(a,'if'):
ifcond = a.args[0]
break
docollapse = True
numifs = 0
for i in range(1,len(newargs)):
if is_op(newargs[i],'if'):
numifs += 1
if not ifcond.match(newargs[i].args[0]):
docollapse = False
break
if _DEBUG_SIMPLIFY:
print(" "*mydepth,"Considering if-folding of operator",self.functionInfo.name,"judgement:",docollapse)
if docollapse:
trueargs = []
falseargs = []
for a in newargs:
if is_op(a,'if'):
trueargs.append(a.args[1])
falseargs.append(a.args[2])
else:
trueargs.append(a)
falseargs.append(a)
trueres = OperatorExpression(self.functionInfo,trueargs)
trueres2 = trueres._postsimplify(depth,mydepth+1)
if trueres2 is not None: trueres = trueres2
if _DEBUG_SIMPLIFY:
if trueres2 is not None:
print(" "*mydepth,"True condition simplified to",trueres2)
else:
print(" "*mydepth,"True condition remains",trueres)
falseres = OperatorExpression(self.functionInfo,falseargs)
falseres2 = falseres._postsimplify(depth,mydepth+1)
if falseres2 is not None: falseres = falseres2
if _DEBUG_SIMPLIFY:
if falseres2 is not None:
print(" "*mydepth,"False condition simplified to",falseres2)
else:
print(" "*mydepth,"False condition remains",falseres)
if numifs > 1 or len(newargs)==1 or trueres2 is not None or falseres2 is not None:
res = if_(ifcond,trueres,falseres)
self._cache['simplified'] = res
return res
else:
if _DEBUG_SIMPLIFY:
print(" "*mydepth,"If-folding of operator",self.functionInfo.name,"ignored, no improvement in expression size")
#expand Functions defined as expressions
if isinstance(self.functionInfo.func,Expression):
res = _subs(None,self.functionInfo.func,self.functionInfo.exprArgRefs,newargs,False,False,False)
if isinstance(res,OperatorExpression):
res2 = res._postsimplify(depth,mydepth+1)
if res2 is not None:
res._clearCache('simplified',deep=True)
res = res2
self._cache['simplified'] = res
if _DEBUG_SIMPLIFY:
print(" "*mydepth,"Simplied via Expression espansion to",res)
return res
if simplified:
self._cache['simplified'] = OperatorExpression(self.functionInfo,newargs,self.op)
if _DEBUG_SIMPLIFY:
if len(self.args) > 1:
print(" "*mydepth,"Simplified one or more arguments of operation",self.functionInfo.name,"in context",self)
else:
print(" "*mydepth,"Simplified argument of operation",self.functionInfo.name,"in context",self)
#print "Simplified args result",self._cache
return self._cache['simplified']
self._cache['simplified'] = None
return None
def _constant_expansion(self,context,depth):
if 'simplified' in self._cache:
return self._cache['simplified']
if _is_exactly(depth,0): return None
newdepth = None if depth is None else depth-1
#constant replacement
simplified = False
aconst = []
newargs = []
#print "Simplifying operation",self.functionInfo.name,"in context",self
for i,a in enumerate(self.args):
a._parent = (weakref.ref(self),i)
asimp = (a._constant_expansion(context,newdepth) if isinstance(a,OperatorExpression) else a._simplify(context))
a._parent = None
if asimp is None:
newargs.append(a)
else:
simplified = True
newargs.append(asimp)
if isinstance(newargs[-1],UserDataExpression):
aconst.append(newargs[-1].eval(context))
if aconst[-1] is newargs[-1]:
aconst[-1] = None
else:
aconst.append(to_const(newargs[-1]))
if not any(a is None for a in aconst):
#print "Constant expansion of",self
#print "Arguments",[str(a) for a in newargs]
#print "Evaluated arguments",[str(a) for a in aconst]
#constant
assert not any(isinstance(a,Expression) for a in aconst),"Hmm... we should have gotten rid of any Expressions"
try:
return ConstantExpression(self._do(aconst))
except Exception as e:
print("Error simplifying function",self.functionInfo.name)
print(" Call stack:", end=' ')
self._print_call_stack()
print("Exception:",e)
raise
return None
if self.functionInfo.properties.get('foldable',False):
commutative = self.functionInfo.properties.get('commutative',False)
if commutative:
avar = []
#check constants that can be folded in
for ac,v in zip(aconst,newargs):
if ac is None:
avar.append(v)
aconst = [ac for ac in aconst if ac is not None and not _is_exactly(ac,0)]
if len(aconst) > 1:
const = self.op(*aconst)
if 'foldfunc' in self.functionInfo.properties:
const,stop = self.functionInfo.properties['foldfunc'](const)
if stop:
return ConstantExpression(const)
if const is not None:
newargs = avar + [ConstantExpression(const)]
simplified = True
else:
#fold streaks of constant arguments together
newargs2 = []
istart = None
for i in range(len(aconst)):
if aconst[i] is None:
#non-constant
if istart is not None:
if i-istart > 1:
#fold istart...i-1 into a constant
const = self.op(*[aconst[j] for j in range(istart,i)])
newargs2.append(ConstantExpression(const))
else:
newargs2.append(newargs[istart])
istart = None
newargs2.append(newargs[i])
else:
if istart is None:
istart = i
if istart is not None:
i = len(aconst)
if i-istart > 1:
#fold istart...i-1 into a constant
const = self.op(*[aconst[j] for j in range(istart,i)])
newargs2.append(ConstantExpression(const))
else:
newargs2.append(newargs[istart])
if len(newargs2) < len(newargs):
#print "Simplified",[str(e) for e in newargs],"to",[str(e) for e in newargs2],"by non-commutative folding"
#raw_input()
newargs = newargs2
if simplified:
if _DEBUG_SIMPLIFY:
print(" "*depth,"Constant expansion of",self)
print(" "*depth," with arguments",[str(a) for a in newargs])
#print "Evaluated arguments",[str(a) for a in aconst]
if self.functionInfo.simplifier != None:
for i,a in enumerate(newargs):
if not isinstance(a,Expression):
if a is not 0:
warnings.warn("symbolic.simplify: function %s shouldn't simplify to constant: %s"%(self.functionInfo.name,a))
newargs[i] = expr(a)
try:
res = self.functionInfo.simplifier(*newargs)
if res is not None:
self._cache['simplified'] = res
if _DEBUG_SIMPLIFY:
print("=>",res)
if isinstance(res,ConstantExpression):
assert res.value is not None
return res
else:
pass
#print "Unable to simplify",self
#print "Call stack",self._print_call_stack()
except Exception as e:
#print "Error simplifying function",self.functionInfo.name
#print " Call stack:",
#self._print_call_stack()
#print "Exception:",e
raise
return None
if _DEBUG_SIMPLIFY:
print(" "*depth,"=>",OperatorExpression(self.functionInfo,newargs,self.op))
self._cache['simplified'] = OperatorExpression(self.functionInfo,newargs,self.op)
return self._cache['simplified']
return None
def _simplify(self,context=None,depth=None,constant_expansion=True):
"""Returns a new expression, or None if this cannot be simplified.
Simplification has three phases:
- pre-simplification (top-down): each expression examines its immediate children for
possible symbolic simplification. Then, it pre-simplifies its descendents.
- post-simplification (bottom-up): each expression post-simplifies its descendants, then
examines its new children for possible symbolic simplification.
- constant expansion (bottom-up): any descendants are replaced with constants, and if any
immediate children are changed, symbolic simplification is tried again.
"""
if _DEBUG_SIMPLIFY:
print("Simplifying expression",self,"to depth",depth)
simplest = self
res = self._presimplify(depth)
if res is not None:
if not isinstance(res,OperatorExpression):
if not isinstance(res,Expression): return ConstantExpression(res)
return res
simplest = res
res = simplest._postsimplify(depth)
if _DEBUG_SIMPLIFY:
print("After pre/postsimplify:",res)
simplest._clearCache('simplified',deep=True)
if res is not None:
if not isinstance(res,OperatorExpression):
if not isinstance(res,Expression): return ConstantExpression(res)
return res
simplest = res
if constant_expansion:
res = simplest._constant_expansion(context,depth)
if _DEBUG_SIMPLIFY: print("After constant expansion:",res)
simplest._clearCache('simplified',deep=True)
if res is not None:
simplest = res
if simplest is self:
if _DEBUG_SIMPLIFY: print("... no simplification possible")
return None
if _DEBUG_SIMPLIFY: print("... simplified to",simplest)
if not isinstance(simplest,Expression): return ConstantExpression(simplest)
return simplest
def _signature(self):
asig = [a._signature() for a in self.args]
return (self.functionInfo.name,sum(asig,()))
[docs]def supertype(types):
assert len(types) > 0
tres = Type(types[0])
for t in types[1:]:
if tres.match(t): continue
if tres.is_scalar() and t.is_scalar():
#promote to N
tres.char = 'N'
elif not tres.is_scalar() and not t.is_scalar():
if tres.char == t.char:
if tres.char == 'L':
if tres.size != t.size:
tres.size = None
if tres.subtype is not None:
tres.subtype = supertype([tres.subtype,t.subtype])
else:
if tres.size != t.size:
tres.size = None
if tres.subtype is not None:
tres.subtype = supertype([tres.subtype,t.subtype])
else:
#mixed types
tres.char = 'L'
tres.subtype = None
else:
#mix of scalars and array types
return Type(None)
return tres
[docs]def type_of(x):
if isinstance(x,Type): return x
elif isinstance(x,Variable): return x.type
elif isinstance(x,Expression): return x.returnType()
else:
if isinstance(x,bool): return Type('B')
elif isinstance(x,_PY_INT_TYPES): return Type('I')
elif isinstance(x,_PY_FLOAT_TYPES): return Type('N')
elif isinstance(x,slice): return Type('X')
elif isinstance(x,np.ndarray):
if len(x.shape) == 1:
return Type('V',x.shape[0])
elif len(x.shape) == 2:
return Type('M',x.shape)
else:
return Type('A',x.shape)
elif hasattr(x,'__iter__'):
#determine whether to turn this to a vector, matrix, array, or list
if len(x) == 0:
return Type('V',0)
itemtype = supertype([type_of(v) for v in x])
if itemtype.is_scalar():
return Type('V',len(x))
elif itemtype.char is None:
return Type('L',len(x))
else:
if itemtype.char in 'AVM' and itemtype.size is not None:
if itemtype.char == 'V':
return Type('M',(len(x),itemtype.size))
else:
return Type('A',(len(x),) + itemtype.size)
else:
return Type('L',len(x),itemtype)
else:
#print("object ",x,"is not of known type, setting as user data of type",x.__class__.__name__)
return Type(x.__class__.__name__)
[docs]def const(v):
return ConstantExpression(v)
[docs]def is_const(v,context=None,shallow=False):
"""Returns True if v is a constant value or constant expression.
If context is not None, then Variables and user data that are bound to values are considered
to be constants. Otherwise, they are not considered to be constants.
If shallow=True, this doesn't try to convert Expressions.
"""
if isinstance(v,Variable):
if v.value is not None:
return True
return False
elif isinstance(v,Expression):
if shallow and isinstance(v,OperatorExpression) and not is_zero(v): return False
if context is None: return v.isConstant()
else: return v.isConstant() or v.returnConstant(context)
elif isinstance(v,_PY_CONST_TYPES):
return True
elif isinstance(v,dict):
return len(v) == 0 or all(is_const(x,context) for x in v.values())
elif hasattr(v,'__iter__'):
try:
return len(v) == 0 or all(is_const(x,context) for x in v)
except TypeError:
raise ValueError("Object of type %s has iter but not len"%(v.__class__.__name__))
return False
else:
return True
[docs]def to_const(v,context=None,shallow=False):
"""Returns the value corresponding to a constant value or expression v. Otherwise returns None.
If context is not None, then Variables and user data that are bound to values are considered
to be constants. Otherwise, they are not considered to be constants.
If shallow=True, this doesn't try to convert Expressions.
"""
if isinstance(v,Variable):
if v.value is not None:
return v.value
return None
elif isinstance(v,ConstantExpression):
return v.value
elif isinstance(v,Expression):
if shallow and isinstance(v,OperatorExpression) and not is_zero(v): return None
if context is None:
if v.isConstant():
return v.evalf()
else:
if v.isConstant() or v.returnConstant(context):
return v.evalf(context)
return None
elif isinstance(v,_PY_CONST_TYPES):
return v
elif isinstance(v,dict):
vc = {}
for k,x in v.items():
kc = to_const(x,context,shallow)
if kc is None: return None
vc[k] = kc
return vc
elif hasattr(v,'__iter__'):
if len(v) == 0: return v
vc = v.__class__(to_const(x,context,shallow) for x in v)
if all(x is not None for x in vc):
return vc
return None
else:
return v
[docs]def is_scalar(v,value=None):
"""Returns True if v evaluates to a scalar. If value is provided, then
returns True only if v evaluates to be equal to value"""
if isinstance(v,Variable):
if not v.type.is_scalar(): return False
return value is None or v.value == value
elif isinstance(v,ConstantExpression):
return is_scalar(v.value,value)
elif isinstance(v,Expression):
rt = v.returnType()
if rt is None: return False
if not rt.is_scalar(): return False
return value is None
else:
if not (isinstance(v,(bool,)+_PY_NUMERIC_TYPES) or (hasattr(v,'shape') and v.shape == ())): return False
return value is None or v == value
[docs]def to_scalar(v):
"""Returns the value of v if v is a scalar"""
if isinstance(v,Variable):
if not v.type.is_scalar(): return None
return None
elif isinstance(v,ConstantExpression):
return to_scalar(v.value)
elif isinstance(v,Expression):
return None
else:
if not (isinstance(v,(bool,)+_PY_NUMERIC_TYPES) or (hasattr(v,'shape') and v.shape == ())): return None
return v
[docs]def is_zero(v):
"""Returns true if v represents a 0 value"""
if is_op(v,'zero'): return True
if isinstance(v,Expression): return False
vc = to_const(v)
if vc is not None:
if isinstance(vc,np.ndarray):
return np.all(vc==0)
else:
return vc == 0
return False
[docs]def is_var(v):
"""Returns True if v is equivalent to a stand-alone variable."""
return isinstance(v,Variable) or isinstance(v,VariableExpression)
[docs]def to_var(v):
"""If v is equivalent to a stand-alone variable, returns the Variable. Otherwise returns None."""
if isinstance(v,Variable):
return v
elif isinstance(v,VariableExpression):
return v.var
return None
[docs]def is_sparse(v,threshold='auto'):
"""Returns true if v is a sparse array, with #nonzeros(v) < threshold(shape(v)).
Args:
v (Expression): the array
threshold (optional): either 'auto', a constant, or a function of a
shape. If threshold=auto, the threshold is sqrt(product(shape))
"""
if isinstance(v,ConstantExpression):
v = v.value
if isinstance(v,Expression):
raise ValueError("is_sparse can only be run on constants")
sh = _shape(v)
nnz = np.count_nonzero(v)
if callable(threshold):
threshold = threshold(sh)
elif threshold == 'auto':
threshold = 3*math.sqrt(np.product(sh))
return (nnz < threshold)
[docs]def expr(x):
"""Converts x to an appropriate Expression as follows:
- If x is a Variable, returns a VariableExpression.
- If x is a non-str constant, returns a ConstantExpression.
- If x is a str, then a UserDataExpression is returned.
- If x is already an Expression, x is returned.
- If x is a list containing expressions, then ``array(*x)`` is returned.
This adds some ambiguity to lists containing strings, like
``expr(['x','y'])``. In this case, strings are recursively converted
to UserDataExpressions.
"""
if isinstance(x,Variable):
return VariableExpression(x)
elif isinstance(x,str):
return UserDataExpression(x)
elif isinstance(x,Expression):
return x
elif isinstance(x,(np.ndarray,dict)):
return ConstantExpression(x)
elif hasattr(x,'__iter__'):
try:
xc = x.__class__(to_const(v) if not isinstance(v,str) else None for v in x)
if all(v is not None for v in xc):
return ConstantExpression(xc)
except Exception as e:
print("symbolic.expr: Error doing conversion to list of constants?",x.__class__.__name__,x)
print("=== EXCEPTION ===")
print(e)
print("=== END EXCEPTION ===")
raise
input()
return array(*[expr(v) for v in x])
return ConstantExpression(x)
[docs]def is_expr(x):
return isinstance(x,Expression)
[docs]def is_op(x,opname=None):
if not isinstance(x,OperatorExpression): return False
if opname is not None: return (x.functionInfo.name == opname)
return True
[docs]def deriv(e,var):
return expr(e).deriv(var)
[docs]def simplify(e,context=None,depth=None):
return expr(e).simplify(context,depth)
[docs]def to_monomial(v):
"""If v is a monomial of a single variable, returns (x,d) where x is the Variable and d is the degree.
Otherwise returns None."""
if isinstance(v,Variable):
return (v,1)
elif isinstance(v,Expression):
if isinstance(v,VariableExpression):
return (v.var,1)
if isinstance(v,OperatorExpression):
if v.functionInfo.name == 'pow':
res = to_monomial(v.args[0])
if res is None: return None
return (res[0],mul.optimized(v.args[1],res[1]))
elif v.functionInfo.name == 'mul':
if len(v.args) != 2:
#print("NOT MONOMIAL, ",str(v),"HAS",len(v.args),"ARGS")
return None
assert len(v.args) == 2
res1 = to_monomial(v.args[0])
res2 = to_monomial(v.args[1])
if res1 is None or res2 is None:
return None
if res1[0].name != res2[0].name:
return None
return (res1[0],add.optimized(res1[1],res2[1]))
return None
[docs]def to_polynomial(expr,x=None,const_only=True):
"""If expr is a polynomial of a single variable, returns x,[c1,...,cn],[d1,...,dn] where
expr = c1 x^d1 + ... + cn x^dn. Otherwise returns None.
If x is given, the polynomial must be a polynomial in x. If const_only=True, the coefficients
and degrees must be constants,
"""
if not const_only:
raise NotImplementedError("TODO: extract polynomial with non-constant coefficients")
econst = to_const(expr)
if x is not None and econst is not None:
return x,[econst],[0]
mon = to_monomial(expr)
if mon != None:
if x is None or mon[0].name == x.name:
return (mon[0],[1],[mon[1]])
else:
return None
uvars = expr.vars()
if len(uvars) == 1:
if x is None or uvars[0].name == x.name:
#must be a combination of +,-,*, and pow
if is_op(expr) and expr.functionInfo.name in ['add','sum','sub','neg','mul','pow']:
x = uvars[0]
pargs = []
for a in expr.args:
pa = to_polynomial(a,x)
if pa is None: return None
pargs.append(pa)
if expr.functionInfo.name in ['neg']:
return (x,[-c for c in pargs[0][1]],pargs[0][2])
elif expr.functionInfo.name in ['add','sum','sub']:
if expr.functionInfo.name == 'sub':
pargs[1] = (pargs[1][0],[-c for c in pargs[1][1]],pargs[1][2])
dtoc = {}
for (x,cs,ds) in pargs:
for c,d in zip(cs,ds):
if d in dtoc:
dtoc[d] += c
else:
dtoc[d] = c
ds = sorted(dtoc.keys())
cs = [dtoc[d] for d in ds]
return (x,cs,ds)
elif expr.functionInfo.name == 'mul':
(x,cas,das) = pargs[0]
(x,cbs,dbs) = pargs[1]
dtoc = {}
for (ca,da) in zip(cas,das):
for (cb,db) in zip(cbs,dbs):
d = da + db
if d in dtoc:
dtoc[d] += cb*ca
else:
dtoc[d] = cb*ca
ds = sorted(dtoc.keys())
cs = [dtoc[d] for d in ds]
return (x,cs,ds)
else:
#TODO: raise powers with binomial
return None
return None
def _dims(v):
if hasattr(v,'shape'): return len(v.shape)
if hasattr(v,'__len__'):
return 1
raise ValueError("Computing dims of some unknown thing? "+v.__class__.__name__)
return 0
def _len(v):
assert not isinstance(v,Variable),"Did you mean to use len_?"
if hasattr(v,'__len__'): return len(v)
return 0
def _count(v):
assert not isinstance(v,Variable),"Did you mean to use count?"
return _hyper_count(v)
def _shape(v):
assert not isinstance(v,Variable),"Did you mean to use shape?"
if hasattr(v,'shape'): return v.shape
if hasattr(v,'__len__'):
return _hyper_shape(v)
return ()
def _basis(i,n):
assert n > 0,"basis must be given a positive integer for its second argument"
assert i < n,"basis index must be less than the size"
res = np.zeros(n)
res[i] = 1.0
return res
def _eye(n):
assert not isinstance(n,Variable),"Did you mean to use eye?"
if n == 0: return 1
return np.eye(int(n))
def _zero(n):
assert not isinstance(n,Variable),"Did you mean to use zero?"
if isinstance(n,dict):
#hyper array
res = [None]*len(n)
for (k,v) in n.items():
res[k] = _zero(v)
return res
if hasattr(n,'__iter__'):
if len(n) == 0: return 0
if isinstance(n,np.ndarray):
n = n.astype(np.int64)
else:
n = tuple(int(v) for v in n)
else:
if n == 0: return 0
n = int(n)
return np.zeros(n)
def _reshape(x,s):
"""Reshapes x to the shape s"""
if x is None: return None
if isinstance(s,np.ndarray):
s = tuple(s.astype(np.int64).tolist())
elif isinstance(s,list):
s = tuple(s)
if _shape(x) == s:
return x
assert s != ()
if _is_exactly(_shape(x),()):
#scalar, just return ones
return np.ones(s)*x
if not _is_numpy_shape(s):
return _unravel(_ravel(x),s)
try:
return np.array(x).reshape(s)
except Exception:
flattened = _ravel(x)
return np.array(flattened).reshape(s)
def _transpose(v):
if not hasattr(v,'__iter__'): return v
if hasattr(v,'shape') and len(v.shape) >= 3:
warnings.warn("symbolic.transpose: being applied to tensor? You probably want transpose2")
return np.transpose(v)
def _diag(v):
if not hasattr(v,'__iter__'): return v
return np.diag(v)
def _if(cond,trueval,falseval):
return trueval if cond else falseval
def _max(*args):
assert len(args) > 0,"Cannot perform max with 0 arguments"
if len(args) == 1: return args[0]
return max(args)
def _min(*args):
assert len(args) > 0,"Cannot perform min with 0 arguments"
if len(args) == 1: return args[0]
return min(args)
def _argmax(*args):
assert len(args) > 0,"Cannot perform argmax with 0 arguments"
if len(args) == 1: return 0
return max((v,i) for i,v in enumerate(args))[1]
def _argmin(*args):
assert len(args) > 0,"Cannot perform argmin with 0 arguments"
if len(args) == 1: return 0
return min((v,i) for i,v in enumerate(args))[1]
def _add(*args):
"""Generalized add"""
if len(args) == 0: return 0
try:
return np.sum(args,axis=0)
except ValueError:
pass
res = args[0]
for v in args[1:]:
if isinstance(res,np.ndarray):
res += v
elif hasattr(res,'__iter__'):
if isinstance(v,np.ndarray):
res = np.array(res)+v
else:
if hasattr(v,'__iter__'):
assert len(v) == len(res)
res = vectorops.add(res,v)
else:
res = vectorops.add(res,[v]*len(res))
else:
if isinstance(v,np.ndarray):
res = v + res
elif hasattr(v,'__iter__'):
res = vectorops.add([res]*len(v),v)
else:
res += v
return res
def _sub(a1,a2):
"""Generalized subtraction"""
return np.subtract(a1,a2)
def _mul(*args):
"""Generalized product"""
assert len(args) > 0,"mul needs at least one argument"
if len(args) == 2:
return np.multiply(args[0],args[1])
elif len(args) > 2:
return np.multiply(args[0],_mul(*args[1:]))
else:
return args[0]
def _div(a1,a2):
"""Generalized product"""
return np.divide(a1,a2)
def _pow(x,y):
return np.power(x,y)
def _weightedsum(*arglist):
if len(arglist) == 0: return 0.0
vals = arglist[:len(arglist)/2]
weights = arglist[len(arglist)/2:]
res = _mul(vals[0],weights[0])
for (v,w) in zip(vals,weights)[1:]:
res += _mul(v,w)
return res
def _getitem(vector,index):
if isinstance(index,slice):
start,stop,step =index.start,index.stop,index.step
if isinstance(start,ConstantExpression):
start = start.value
if isinstance(stop,ConstantExpression):
stop = stop.value
if isinstance(step,ConstantExpression):
step = step.value
index = slice(start,stop,step)
if hasattr(index,'__iter__'):
return [vector[i] for i in index]
else:
return vector[index]
def _lazy_getitem(context,vector,index):
#lazy evaluation of list entries
eindex = index.eval(context)
if isinstance(eindex,slice):
start,stop,step = eindex.start,eindex.stop,eindex.step
if isinstance(start,ConstantExpression):
start = start.value
if isinstance(stop,ConstantExpression):
stop = stop.value
if isinstance(step,ConstantExpression):
step = step.value
eindex = slice(start,stop,step)
if is_op(vector,'array'):
#special lazy evaluation for array
if is_const(eindex):
if hasattr(eindex,'__iter__'):
return array([vector.args[i] for i in eindex]).eval(context)
else:
return vector.args[eindex]._eval(context)
return vector._eval(context)[eindex]
else:
evector = vector._eval(context)
if is_const(eindex):
if is_const(evector):
evector = to_const(evector)
if isinstance(eindex,(list,tuple,np.ndarray)):
if isinstance(evector,np.ndarray):
return evector[eindex]
elif isinstance(evector,(list,tuple)):
return np.array(evector)[eindex]
if isinstance(eindex,(list,np.ndarray)):
#fancy indexing
if isinstance(evector,Expression):
return evector[eindex]
#return flatten(*[vector[i] for i in eindex]).eval(context)
return [evector[i] for i in eindex]
return evector[eindex]
warnings.warn("Non-constant getitem? {} class {}".format(eindex,eindex.__class__.__name__))
return evector[eindex]
def _elementary_basis(n,index):
#print "n=",n,"index=",index
d = np.zeros(n)
d[to_const(index)] = 1
return d
def _array(*args):
if len(args)==0: return []
try:
arr = np.array(args)
if arr.dtype.char == 'O': #couldn't make a proper array
return list(args)
return arr
except Exception:
return list(args)
def _array_returnType(*args):
if len(args) == 0:
return Type('L',0)
atypes = [type_of(a) for a in args]
atype = supertype(atypes)
if atype.is_scalar():
res = Type('V',len(args))
res.subtype = atype
return res
if atype.char is not None and atype.char in 'VMA' and atype.size is not None:
res = Type('A',(len(args),)+atype.shape())
res.subtype = atype.subtype
return res
#default
#print("Could not deduce return type of array of arg types",[str(t) for t in atypes])
#if len(atypes) >= 2:
# for i in range(1,len(atypes)):
# print("match",i,"?",atypes[0].match(atypes[1]))
return Type('L',len(args),atypes)
def _array_simplifier(*args):
#reassemble
if len(args) > 0 and all(is_op(a,'getitem') for a in args):
v = args[0].args[0]
i0 = to_const(args[0].args[1])
if i0 is None or not isinstance(i0,_PY_INT_TYPES):
return None
for (i,a) in enumerate(args):
if not is_scalar(a.args[1],i+i0):
return None
lv = to_const(len_.optimized(v))
if lv is not None and i0 == 0 and lv == len(args):
return v
else:
assert lv is None or i0+len(args) <= lv
return v[i0:i0+len(args)]
if len(args) > 0 and all(is_op(a,'zero') for a in args):
ash = [shape.optimized(a) for a in args]
sconst = to_const(ash[0])
if sconst:
convertable = all(to_const(a) == sconst for a in ash[1:])
if convertable:
return zero([len(args),] + list(sconst))
def _list_simplifier(*args):
return _array_simplifier(*args)
def _tuple_simplifier(*args):
return tuple_(_array_simplifier(*args))
def _list_returnType(*args):
atypes = [type_of(a) for a in args]
atype = supertype(atypes)
return Type('L',len(args),atypes)
def _flatten(*args):
if len(args)==0: return []
return _ravel(args)
def _row_stack(*args):
if len(args)==0: return []
if len(args)==1 and hasattr(args[0],'__iter__'):
return np.row_stack(args[0])
return np.row_stack(args)
def _column_stack(*args):
if len(args)==0: return []
if len(args)==1 and hasattr(args[0],'__iter__'):
return np.column_stack(args[0])
return np.column_stack(args)
def _setitem(x,indices,rhs):
if isinstance(indices,(list,tuple)):
xcopy = np.array(x)
elif isinstance(x,(list,tuple)):
xcopy = list(x)
else:
xcopy = np.array(x)
xcopy[indices] = rhs
return xcopy
def _getattr(object,attr):
assert isinstance(attr,str)
res = getattr(object,attr)
if callable(res):
return res()
return res
def _setattr(object,attr,val):
assert isinstance(attr,str) or isinstance(attr,_PY_INT_TYPES)
if callable(getattr(object,attr)):
getattr(object,attr)(val)
else:
setattr(object,attr,val)
return object
def _subs(context,expr,var,value,evalexpr=True,evalvar=True,evalvalue=True):
if not isinstance(expr,Expression):
return expr
if is_op(var,'array') or is_op(var,'list'):
var = var.args
if isinstance(var,ConstantExpression):
#may be an empty list
var = var.value
if hasattr(var,'__iter__'):
if isinstance(value,ConstantExpression):
value = value.value
evalvalue = False
assert hasattr(value,'__iter__') or is_op(value,'array'),"Variable list given, but value is not a list, %s has type %s"%(str(value),value.__class__.__name__)
if is_op(value,'array') or is_op(value,'list'):
value = value.args
assert len(var) == len(value)
#if evalexpr: expr = expr._eval(context)
if evalvar: var = [v._eval(context) for v in var]
if evalvalue: value = [v._eval(context) for v in value]
if all(is_const(x,context) for x in value):
if isinstance(context,Context):
context = context.userData
elif context is None:
context = {}
dummy = _setattr
#this is a bit more efficient than expr.replace(var,value)
oldvals = []
for v,x in zip(var,value):
if not isinstance(v,(VariableExpression,UserDataExpression)):
raise ValueError("var argument must be a list of variables or userDatas, got an instance of "+v.__class__.__name__)
if isinstance(v,VariableExpression):
oldvals.append(v.var.value)
v.var.value = x
elif isinstance(v,UserDataExpression):
oldval = context.get(v.name,dummy) if context is not None else None
context[v.name] = x
oldvals.append(oldval)
res = expr._eval(context)
expr._clearCache('eval')
#restore previous values in context
for v,x in zip(var,oldvals):
if isinstance(v,VariableExpression):
v.var.value = x
else:
if x is dummy:
del context[v.name]
else:
context[v.name] = x
return res
else:
#slow version: replacing variables with expressions
for v,x in zip(var,value):
expr = expr.replace(v,x,error_no_match=False)
return expr
#if evalexpr: expr = expr._eval(context)
#print("original var argument:",var,var.__class__.__name__)
if evalvar: var = var._eval(context)
if evalvalue: value = value._eval(context)
if not isinstance(var,(VariableExpression,UserDataExpression)):
#print("var argument:",var)
raise ValueError("var argument must be a variable or a userData, got an instance of "+var.__class__.__name__)
if is_const(value,context):
if isinstance(var,VariableExpression):
#this is a bit more efficient than expr.replace(var,value)
oldval = var.var.value
var.var.value = to_const(value,context)
assert var.var.value is not None
res = expr._eval(context)
expr._clearCache('eval')
var.var.value = oldval
return res
if isinstance(var,UserDataExpression):
#this is a bit more efficient than expr.replace(var,value)
if isinstance(context,Context):
context = context.userData
elif context is None:
context = {}
dummy = _setattr
oldval = context.get(var.name,dummy) if context is not None else None
context[var.name] = value
res = expr.eval(context)
expr._clearCache('eval')
if oldval is dummy:
del context[var.name]
else:
context[var.name] = oldval
return res
return expr.replace(var,value,error_no_match=False)
def _map(context,expr,var,values):
if isinstance(values,Expression):
values = values._eval(context)
var = var._eval(context)
#expr = expr._eval(context)
if not isinstance(var,(VariableExpression,UserDataExpression)):
raise ValueError("var argument must be a variable or a userData")
if is_op(values,'array'):
values = values.args
if is_op(values):
expr = expr.eval(context)
return map_(expr,var,values)
resargs = [_subs(context,expr,var,value,evalexpr=False,evalvar=False,evalvalue=False) for value in values]
if all(is_const(e,context) for e in resargs):
return array(*resargs).evalf()
else:
return array(*resargs)
def _lazy_all(context,*args):
for i,a in enumerate(args):
a = a._eval(context,clearcache=False)
ac = to_const(a)
if ac is not None and not ac:
for j in range(i+1):
args[j]._clearCache('eval')
return False
for j in range(len(args)):
args[j]._clearCache('eval')
return True
def _lazy_any(context,*args):
for i,a in enumerate(args):
a = a._eval(context,clearcache=False)
ac = to_const(a)
if ac is not None and ac:
for j in range(i+1):
args[j]._clearCache('eval')
return True
for j in range(len(args)):
args[j]._clearCache('eval')
return False
def _forall(context,expr,var,values):
if isinstance(values,Expression):
values = values._eval(context)
var = var._eval(context)
#expr = expr._eval(context)
if not isinstance(var,(VariableExpression,UserDataExpression)):
raise ValueError("var argument must be a variable or a userData")
if is_op(values,'array'):
values = values.args
if is_op(values):
return forall(expr,var,values)
anyexpr = False
resargs = []
for value in values:
e = _subs(context,expr,var,value,evalexpr=False,evalvar=False,evalvalue=False)
ec = to_const(e)
if ec is not None:
if not ec:
return ConstantExpression(False)
else:
anyexpr = True
resargs.append(e)
if anyexpr:
return all_(*resargs)
return ConstantExpression(True)
def _forsome(context,expr,var,values):
if isinstance(values,Expression):
values = values._eval(context)
var = var._eval(context)
#expr = expr._eval(context)
if not isinstance(var,(VariableExpression,UserDataExpression)):
raise ValueError("var argument must be a variable or a userData")
if is_op(values,'array'):
values = values.args
if is_op(values):
return forsome(expr,var,values)
anyexpr = False
resargs = []
for value in values:
e = _subs(context,expr,var,value,evalexpr=False,evalvar=False,evalvalue=False)
ec = to_const(e)
if ec is not None:
if ec:
return ConstantExpression(True)
else:
anyexpr = True
resargs.append(e)
if anyexpr:
return all_(*resargs)
return ConstantExpression(False)
def _summation(context,expr,var,values):
if isinstance(values,Expression):
values = values._eval(context)
var = var._eval(context)
#expr = expr._eval(context)
if not isinstance(var,(VariableExpression,UserDataExpression)):
raise ValueError("var argument must be a variable or a userData")
if is_op(values,'array'):
values = values.args
if is_op(values):
expr = expr.eval(context)
return summation(expr,var,values)
anyexpr = False
resargs = []
for value in values:
e = _subs(context,expr,var,value,evalexpr=False,evalvar=False,evalvalue=False)
resargs.append(e)
if not is_const(e,context):
anyexpr = True
if anyexpr:
return sum_(*resargs)
return sum_(*resargs).eval()
def _propagate_returnType(*args):
return supertype([type_of(a) for a in args])
def _promote_returnType(*args):
if len(args) == 0: return Type('N')
types = [type_of(a) for a in args]
assert len(types) > 0
tres = Type(types[0])
for t in types[1:]:
if tres.match(t): continue
if tres.is_scalar():
if t.is_scalar():
#promote to N
tres.char = 'N'
else:
tres = t
else:
if t.is_scalar():
pass
else:
raise ValueError("Incompatible types %s"%(','.join(str(t) for t in types),))
return tres
def _if_returnType(cond,trueval,falseval):
return supertype([type_of(trueval),type_of(falseval)])
def _returnType1(*args):
return args[0].returnType()
def _returnType2(*args):
return args[1].returnType()
def _returnType3(*args):
return args[2].returnType()
#all of these are symbolic Functions in the default namespace
range_ = Function('range',lambda x:np.array(list(range(x))),['n'],returnType='V')
len_ = Function('len',_len,['x'],returnType='I')
count = Function('count',_count,['x'],returnType='I')
shape = Function('shape',_shape,['x'],returnType='V')
reshape = Function('reshape',_reshape,['x','s'])
transpose = Function('transpose',_transpose,['x'])
transpose2 = Function('transpose2',lambda x,axes:np.transpose(x,axes),['x','axes'])
dims = Function('dims',_dims,['x'],returnType='I')
eye = Function('eye',_eye,['n'],returnType='M')
basis = Function('basis',_basis,['i','n'],returnType='V')
zero = Function('zero',_zero,['n'],returnType='A')
diag = Function('diag',_diag,['x'],returnType='M')
eq = Function('eq',operator.eq,['lhs','rhs'],returnType='B')
ne = Function('ne',operator.ne,['lhs','rhs'],returnType='B')
le = Function('le',operator.le,['lhs','rhs'],returnType='B')
ge = Function('ge',operator.ge,['lhs','rhs'],returnType='B')
not_ = Function('not',operator.not_,['x'],returnType='B')
or_ = Function('or',operator.or_,['x','y'],returnType='B')
and_ = Function('and',operator.and_,['x','y'],returnType='B')
neg = Function('neg',operator.neg,['x'],returnType=_propagate_returnType)
abs_ = Function('abs',np.abs,['x'],returnType=_propagate_returnType)
sign = Function('sign',np.sign,['x'],returnType=_propagate_returnType)
add = Function('add',_add,['x','y'],returnType=_promote_returnType)
sub = Function('sub',_sub,['x','y'],returnType=_promote_returnType)
mul = Function('mul',_mul,'...',returnType=_promote_returnType)
div = Function('div',_div,['x','y'],returnType=_promote_returnType)
pow_ = Function('pow',_pow,['x','y'],returnType=_promote_returnType)
dot = Function('dot',np.dot,['x','y'],'A')
outer = Function('outer',np.outer,['x','y'],'A')
tensordot = Function('tensordot',np.tensordot,['x','y','axes'],'A')
if_ = Function('if',_if,['cond','trueval','falseval'],returnType=_if_returnType)
if_.returnTypeDescription = "either the type of trueval or falseval"
max_ = Function('max',_max,'...',returnType=_propagate_returnType)
min_ = Function('min',_min,'...',returnType=_propagate_returnType)
argmax = Function('argmax',_argmax,'...',returnType='I')
argmin = Function('argmin',_argmin,'...',returnType='I')
cos = Function('cos',np.cos,['x'],returnType=_propagate_returnType)
sin = Function('sin',np.sin,['x'],returnType=_propagate_returnType)
tan = Function('tan',np.tan,['x'],returnType=_propagate_returnType)
arccos = Function('arccos',np.arccos,['x'],returnType=_propagate_returnType)
arcsin = Function('arcsin',np.arcsin,['x'],returnType=_propagate_returnType)
arctan = Function('arctan',np.arctan,['x'],returnType=_propagate_returnType)
arctan2 = Function('arctan2',np.arctan2,['x'],returnType=_propagate_returnType)
sqrt = Function('sqrt',np.sqrt,['x'],returnType=_propagate_returnType)
exp = Function('exp',np.exp,['x'],returnType=_propagate_returnType)
log = Function('log',np.log,['x'],returnType=_propagate_returnType)
ln = Function('ln',np.log,['x'],returnType=_propagate_returnType)
sum_ = Function('sum',_add,'...',returnType=_promote_returnType)
any_ = Function('any',any,'...',returnType='B')
all_ = Function('all',all,'...',returnType='B')
getitem = Function('getitem',_getitem,['vec','index'])
setitem = Function('setitem',_setitem,['vec','index','val'])
#Deprecated in Python 3...
#getslice = Function('getslice',operator.getslice,'...',returnType='L')
getattr_ = Function('getattr',_getattr)
setattr_ = Function('setattr',_setattr)
flatten = Function('flatten',_flatten,'...',returnType='V')
row_stack = Function('row_stack',_row_stack,'...',returnType='M')
column_stack = Function('column_stack',_column_stack,'...',returnType='M')
array = Function('array',_array,'...',returnType='A')
list_ = Function('list',lambda *args:args,'...',returnType='L')
tuple_ = Function('tuple',lambda *args:tuple(args),'...',returnType='L')
zip_ = Function('zip',lambda *args:list(zip(*args)),'...',returnType='L')
weightedsum = Function('weightedsum',_weightedsum,'...',returnType=_propagate_returnType)
subs = Function('subs',_subs,['expr','var','value'],returnType=_returnType1)
map_ = Function('map',_map,['expr','var','values'],returnType='L')
forall = Function('forall',_forall,['expr','var','values'],returnType='B')
forsome = Function('forsome',_forsome,['expr','var','values'],returnType='B')
summation = Function('summation',_summation,['expr','var','values'])
getitem.custom_eval = _lazy_getitem
any_.custom_eval = _lazy_any
all_.custom_eval = _lazy_all
subs.custom_eval = _subs
map_.custom_eval = _map
forall.custom_eval = _forall
forsome.custom_eval = _forsome
summation.custom_eval = _summation
range_.description = """Equivalent to Python's range(n) function."""
dims.description = """Returns the dimensionality of the input. Equivalent to len(shape(x))."""
len_.description = """Equivalent to Python's len(x) function, except if x is a scalar then it evaluates to 0.
If x is a multi-dimensional array, this is the length of its first dimension. Undefined for other
forms of user data."""
count.description = """Evaluates the number of numeric parameters in x. Works with scalars, arrays, and lists.
If x is a scalar then it evaluates to 1. Lists are evaluated recursively. Undefined for other forms of
user data."""
shape.description = """Evaluates to x.shape if x is a Numpy array, (len_(x),) if x is a vector, and () if x is
a scalar. If x is a list, this is (len_(x),)+shape(item) if all of the items have the same shape,
otherwise it is a hyper-shape. (shortcut: "x.shape")"""
reshape.description = """Evaluates to x reshaped to the shape s. If x is a scalar, this evaluates to a constant
matrix."""
transpose.description = """Evaluates to np.transpose(x). (shortcut: "x.T")"""
transpose2.description = """Evaluates to np.transpose(x,axes). """
basis.description = """Evaluates to the i'th elementary basis vector in dimension n."""
eye.description = """Evaluates to np.eye(n) if n > 0, otherwise returns 1."""
zero.description = """Evaluates to np.zeros(s) if s is a matrix shape or scalar > 0, otherwise evaluates to 0."""
diag.description = """Evaluates to np.diag(x)."""
flatten.description = """Evaluates to a vector where all arguments are stacked (concatenated) into a single
vector. Arrays are reduced to vectors by Numpy's flatten(), and complex objects are reduced via
recursive flattening."""
row_stack.description = """Evaluates to a matrix where all arguments are stacked vertically. 1D arrays are
treated as row vectors. If only a single list element is provided, then all arguments are stacked."""
column_stack.description = """Evaluates to a matrix where all arguments are stacked horizontally. 1D arrays are
treated as column vectors. If only a single list element is provided, then all arguments are stacked."""
eq.description = """Evaluates to lhs = rhs (shortcut: "lhs = rhs")."""
ne.description = """Evaluates to lhs != rhs (shortcut: "lhs != rhs")."""
le.description = """Evaluates to lhs <= rhs (shortcut: "lhs <= rhs")."""
ge.description = """Evaluates to lhs >= rhs (shortcut: "lhs >= rhs")."""
not_.description = """Evaluates to not x (shortcut: "not x")."""
or_.description = """Evaluates to x or y (shortcut: "x or y")."""
and_.description = """Evaluates to x and y (shortcut: "x and y")."""
any_.description = """Evaluates to any(*args)"""
all_.description = """Evaluates to all(*args)"""
neg.description = """Evaluates to -x (shorcut "-x")."""
abs_.description = """Evaluates to abs(x) (shortcut: "abs(x)"). Works with arrays too (elementwise)"""
sign.description = """Evaluates the sign of x. Works with arrays too (elementwise)."""
add.description = """Evaluates to x + y (shortcut: "x + y"). Works with arrays too."""
sub.description = """Evaluates to x y (shortcut: "x - y"). Works with arrays too, and vector - scalar."""
mul.description = """Evaluates to x * y (shortcut: "x * y"). Works with arrays too (elementwise multiplication)."""
div.description = """Evaluates to x / y (shortcut: "x / y"). Works with arrays too (elementwise division), and
vector / scalar."""
pow_.description = """Evaluates to pow(x,y) (shortcut "x**y")."""
dot.description = """Evaluates to np.dot(x,y)."""
outer.description = """Evaluates to np.outer(x,y)."""
tensordot.description = """Evaluates to np.tensordot(x,y,axes)."""
max_.description = """Evaluates to the maximum of the arguments."""
min_.description = """Evaluates to the minimum of the arguments."""
argmax.description = """Evaluates to the index of the maximum of the argments."""
argmin.description = """Evaluates to the index of the minimum of the argments."""
cos.description = """Evaluates to np.cos(x)."""
sin.description = """Evaluates to np.sin(x)."""
tan.description = """Evaluates to np.tan(x)."""
arccos.description = """Evaluates to np.arccos(x)."""
arcsin.description = """Evaluates to np.arcsin(x)."""
arctan.description = """Evaluates to np.arctan(x)."""
arctan2.description = """Evaluates to np.arctan2(x)."""
sqrt.description = """Evaluates to math.sqrt(x)."""
exp.description = """Evaluates to math.exp(x)."""
log.description = """Evaluates to math.log(x) (base 10)."""
ln.description = """Evaluates to math.ln(x) (natural log)."""
sum_.description = """Evaluates to sum(args). If arguments are vectors or matrices, then the result is also a vector
or matrix. This is somewhat different behavior from sum(x) if x is a list."""
weightedsum.description = """Evaluates to w1*v1+...+wn*vn. """
getitem.description = """Evaluates to vec[index]. This also supports slices and tuples, as well as lists (Numpy fancy
indexing). (shortcut: "vec[index]")"""
setitem.description = """Evaluates to vec except with vec[index] set to val. Equivalent to Python code "temp = vec[:];
vec[index]=val; return temp" """
getattr_.description = """returns the value of a given attribute under the given object. For example,
getattr_(traj,const("milestones")) gets the milestone list of a user-data Trajectory named traj.
If the result is a function (e.g., a getter method), it will be called with no arguments. For example,
getattr_(robot,const("getJointLimits")) will return the robot's joint limits. (shortcut:
"object.attr", where object is a UserDataExpression)"""
setattr_.description = """returns a modified version of the given class object, where value is assigned to the
attribute attr. For example, setattr_(traj,const("times"),[0,0.5,0.1]) sets the times attribute of a
Trajectory to [0,0.5,1]. Note: this operation modifies the object itself and returns it.
If the attribute is a function (e.g., a setter method), it will be called with the argument val."""
if_.description = """If cond evaluates to True, this expression evaluates to trueval. Otherwise, it evaluates to
falseval."""
array.description = """Creates a list or numpy array of the given arguments. This can accept arbitrary arguments,
such as variable size vectors. A Numpy array is produced only if the items have compatible types
and dimensions."""
list_.description = """Creates a list of the given arguments. This can accept arbitrary arguments, such as variable
size vectors. No attempt is made to convert to a numpy array."""
tuple_.description = """Creates a tuple of the given arguments. This can accept arbitrary arguments, such as variable
size vectors."""
zip_.description = """Equivalent to the Python zip function, returning a list of tuples."""
subs.description = """evaluates expr with var substituted with value. For example, subs(const(2)*"i","i",3.5)
yields 2*3.5"""
map_.description = """like the Python map function, evaluates to a list where each entry evaluates expr with var
substituted with a value from the list values. For example, if x is a Variable, map_(x**"i","i",range(3))
yields the list [x**0, x**1, x**2]"""
forall.description = """True if, for every value in the list values, expr evaluates to nonzero when var is substituted
with that value. Equivalent to all_(*[subs(expr,var,value) for value in values])"""
forsome.description = """True if, for some value in the list values, expr evaluates to nonzero when var is substituted
with that value. Equivalent to any_(*[subs(expr,var,value) for value in values])"""
summation.description = """The sum of expr over var when var is substitued with each value in the list values.
Equivalent to sum_(*[subs(expr,var,value) for value in values])"""
_builtin_functions = {'dims':dims,'len':len_,'count':count,'shape':shape,'reshape':reshape,'transpose':transpose,'transpose2':transpose2,
'range':range_,
'eye':eye,'basis':basis,'zero':zero,'diag':diag,
'eq':eq,'ne':ne,'ge':ge,'le':le,
'not':not_,'or':or_,'and':and_,
'neg':neg,'abs':abs_,'sign':sign,
'add':add,'sub':sub,'mul':mul,'div':div,
'pow':pow_,
'dot':dot,
'if':if_,
'max':max_,'min':min_,
'argmax':argmax,
'argmin':argmin,
'cos':cos,'sin':sin,'tan':tan,'sqrt':sqrt,
'arccos':arccos,'arcsin':arcsin,'arctan':arctan,'arctan2':arctan2,
'sum':sum_,
'any':any_,'all':all_,
'getitem':getitem, 'setitem':setitem,
#Deprecated in Python 3
#'getslice':getslice,
'getattr':getattr_,'setattr':setattr_,
'flatten':flatten,'row_stack':row_stack,'column_stack':column_stack,
'array':array,'list':list_,'tuple':tuple_,'zip':zip_,
'weightedsum':weightedsum,
'subs':subs,'map':map_,'forall':forall,'forsome':forsome,'summation':summation
}
_infix_operators = {'and':' and ','or':' or ','sub':'-','div':'/','pow':'**','ge':'>=','le':'<=','eq':'==','ne':'!='}
_prefix_operators = {'not':'not ','neg':'-'}
_python_naming_changes = {'and':'and_',
'or':'or_',
'not':'not_',
'len':'len_',
'abs':'abs_',
'if':'if_',
'pow':'pow_',
'sum':'sum_',
'any':'any_',
'max':'max_',
'min':'min_',
'all':'all_',
'list':'list_',
'tuple':'tuple_',
'zip':'zip_'
}
#define simplifiers and derivatives
def _dims_simplifier(x):
try:
res = x.returnType().dims()
if res is not None: return res
except Exception as e:
pass
if isinstance(x,Variable):
try:
res = x.type.dims()
if res is not None: return res
except Exception as e:
pass
if is_op(x):
if is_op(x,'getitem'):
assert NotImplementedError("Shouldn't get here")
if isinstance(to_const(x.args[1]),(list,tuple,slice)):
return dims(x.args[0])
else:
return dims(x.args[0])-1
if is_op(x,'setitem'):
assert NotImplementedError("Shouldn't get here")
return dims(x.args[0])
if is_op(x,'reshape'):
if is_op(x.args[1],'shape'):
return dims(x.args[1].args[0])
return len_(x.args[1])
elif is_const(x):
if isinstance(x,np.ndarray):
return len(x.shape)
try:
return type_of(x).dims()
except Exception:
#no way to determine dimensions of unknown type
return None
return None
def _len_simplifier(x):
if is_scalar(x): return 0
if isinstance(x,OperatorExpression):
try:
res = x.returnType().len()
if res is not None: return res
except Exception:
pass
if is_op(x,'zero'):
if hasattr(x.args[0],'__iter__'):
return x.args[0][0]
return x.args[0]
if is_op(x,'basis'):
return x.args[1]
if is_op(x,'eye'):
return x.args[0]
if is_op(x,'diag'):
return len_(x.args[0])
if is_op(x,'getitem'):
if isinstance(to_const(x.args[1]),(list,tuple)):
return (len(x.args[1]),)
if is_op(x,'if'):
cond,trueval,falseval = x.args
return if_(cond,len_(trueval),len_(falseval))
if is_op(x,'flatten'):
return sum_(*[count._call(a) for a in x.args])
if is_op(x,'shape'):
rt = x.args[0].returnType()
if rt is None: return None
if rt.char == 'A': return None if rt.size is None else len(rt.size)
elif rt.char == 'V': return 1
elif rt.char == 'M': return 2
elif rt.char == 'L': return rt.size
else: return 0
elif isinstance(x,VariableExpression):
if x.var.type.size is not None:
return x.var.type.len()
elif isinstance(x,Variable):
try:
return x.type.len()
except Exception as e:
pass
elif hasattr(x,'__len__'):
return len(x)
#print "Can't yet simplify len(",str(x),")","return type",x.returnType()
return None
def _count_simplifier(x):
if is_scalar(x): return 1
if is_op(x):
try:
c = x.returnType().count()
if c is not None:
return c
except Exception as e:
pass
elif isinstance(x,VariableExpression):
if x.var.type.size is not None:
return x.var.type.count()
elif isinstance(x,Variable):
if x.type.size is not None:
return x.type.count()
if is_op(x,'shape'):
return dims(x.args[0])
return None
def _shape_simplifier(x):
if is_scalar(x): return ()
try:
res = x.returnType().shape()
if res is not None:
return res
except Exception as e:
pass
if isinstance(x,VariableExpression):
if x.var.type.size is not None:
return x.var.type.shape()
elif isinstance(x,Variable):
if x.type.size is not None:
return x.type.shape()
#TEST: new simplifier
"""
if is_op(x):
if is_op(x,'zero'):
return x.args[0]
if is_op(x,'basis'):
return x.args[1]
if is_op(x,'eye'):
return flatten(x.args[0],x.args[0])
if is_op(x,'diag'):
return flatten(len_(x.args[0]),len_(x.args[0]))
if is_op(x,'dot'):
return dotshape(x.args[0],x.args[1])
if is_op(x,'getitem') and isinstance(to_const(x.args[1]),(list,tuple)):
return (len_(x.args[1]),)
if is_op(x,'setitem'):
return shape(x.args[0])
if is_op(x,'reshape'):
return x.args[1]
if is_op(x,'flatten'):
counts = [count(a) for a in x.args]
return array(sum(counts))
#print "Can't yet simplify shape(",str(x),")","return type",x.returnType()
"""
return None
def _eye_returnType(N):
Nc = to_const(N)
if Nc is not None:
return Type('M',(Nc,Nc))
return Type('M')
def _basis_returnType(i,n):
Nc = to_const(n)
if Nc is not None:
return Type('V',Nc)
return Type('V')
def _zero_returnType(sh):
shc = to_const(sh)
if shc is not None:
if isinstance(shc,int):
if shc == 0:
return Type('N')
else:
return Type('V',shc)
if len(shc) == 0:
return Type('N')
elif len(shc) == 1:
return Type('V',shc[0])
elif len(shc) == 2:
return Type('M',shc)
else:
return Type('A',shc)
try:
d = sh.returnType().dims()
if d <= 1:
return Type('V')
elif d == 2:
return Type('M')
else:
return Type('A')
except Exception:
return Type('A')
def _reshape_returnType(x,sh):
shc = to_const(sh)
if shc is not None:
return Type('A',shc)
l = simplify(len_(sh))
lc = to_const(l)
if lc is not None:
if lc == 0:
return Type('N')
elif lc == 1:
return Type('V')
return Type('A')
"""
#is there some way to getting the length of the shape?
try:
print("reshape returntype not const",sh)
d = sh.returnType().dims()
if d <= 1:
return Type('V')
else:
#might have a hyper-shape?
return Type('L')
except Exception:
return Type('A')
"""
def _reshape_presimplifier(x,sh):
if is_op(x,'reshape'):
return reshape(x.args[0],sh)
if is_op(x,'flatten'):
return reshape(x.args[0],sh)
if is_op(x,'zero'):
return zero(sh)
try:
xsh = x.returnType().shape()
except Exception:
return None
if is_const(sh):
shc = to_const(sh)
if shc is not None:
if np.array_equal(xsh,shc):
return x
if len(shc) == 1:
return flatten(x)
elif len(shc) == 0:
assert