"""Basic definitions for autodifferentiation.

Overview
========

This module defines a simple autodiff system for Klamp't objects, including
3D math, kinematics, dynamics, and geometry functions. This is useful for
all sorts of optimization functions, e.g., machine learning, calibration,
and trajectory optimization.

This can be used standalone, but is more powerful in conjunction with
PyTorch or CadADi. Use the functions in :mod:pytorch and :mod:casadi
to convert these Klamp't functions to PyTorch or CasADi functions.

Calling and evaluating functions
================================

When you call an ADFunctionInterface object on some arguments, the function
isn't immediately called. Instead, it creates an expression DAG that can
either be evaluated using eval(), or a derivative can be taken with
derivative(), or the expression itself can be passed as an argument to

Many built-in Python numeric functions and some Numpy functions are made
available in autodiff variants. Standard numeric operators like +, -, *, /, **,
abs, and sum will automatically create autodiff-able expressions.

Array indexing works properly on expressions.  To create larger arrays from
scalars or smaller arrays, you must use the ad.stack function, you can't
just write [expr1,expr2].

Variables are named by strings. If you pass a string to a function, it creates
an expression that is dependent on a variable.  It can later be evaluated by
passing the named variables as arguments to eval().  For example::

import numpy as np
print(mul('x','y').eval(x=2,y=np.array([3.0,4.0]))) #gives the vector [6,8]

Note that if want to use variables in Python operators, you can't simply do
something like 'x'+'y' since Python sees this as string addition. Instead,
you should explicitly declare a variable using the var keyword:
var('x') + var('y').

Everything evaluates to a scalar or 1D array
============================================

An important note is that all arguments to functions, return values, and
variable assignments must either be scalars or 1D numpy arrays.

Derivatives of a function are always of shape (size(f(x)),size(x)).

If you are creating your own function that can accept either scalars or
of a scalar is 1.

.. note::
Arrays should be float type! Weird results can result if you, for example
initialize your arrays with ints, e.g., np.array([0,0,0])

Expression DAG and performance
================================

The expression DAG can be nested arbitrarily deeply. For improved performance,
you can reuse the same sub-expression multiple times in a DAG. For example, in
the following setup::

diff = var('x')**2 - var('y')**2
poly = diff**3 - 2.0*diff**2 + 4.0*diff - 6
poly.eval(x=2.5,y=1.5)

the term diff will only be evaluated once and then reused through the
evaluation of poly. Similarly, derivatives will be smart in reusing
common subexpressions constructed in this fashion.  For complex subexpressions,
this can be much faster than evaluating the full expression.

Note, however, it is the user's job to construct expressions such that common
subexpressions are the same Python object.  So the following code::

x = var('x')
y = var('y')
poly = (x**2-y**2)**3 - 2.0*(x**2-y**2)**2 + 4.0*(x**2-y**2) - 6
poly.eval(x=2.5,y=1.5)

doesn't know that the term (x**2-y**2) is repeated, and instead it will be
recalculated three times.

Caching is only done within a single eval call. If you would like to have
multiple expressions evaluated with a single eval call, you can either put them
in a stack expression, or use the :func:eval_multiple function.

If you print an expression DAG that contains repeated subexpressions, you will
see terms like #X and @X.  The term #X after a subexpression indicates that
this subexpression may be reused, and @X indicates an instance of reuse.

Derivatives
===========

To get the derivative of an expression, call :func:ADFunctionCall.derivative.
This will use the chain rule to handle arbitrarily nested expressions.  A
derivative is always a 2D array, even if the expression or the argument is
scalar.

Finite differences is performed automatically if some internal function doesn't
have derivative information defined.  To modify the finite difference step size,
modify FINITE_DIFFERENCE_RES.

===========================

The easiest way to create your own function is the :func:function function.
Just pass a Python function to it, indicate the input and output sizes (for
better error handling), and any available derivatives as keywords.  This will
then automatically generate an ADFunctionInterface instance that can be called
to produce an auto-differentable function.

You can either provide a derivative function, J(arg,x1,...,xn), which
returns the Jacobian :math:\\frac{df}{dx_{arg}}(x1,..,xn) , or a jvp
function Jdx(arg,dx,x1,...,xn), which calculates the jacobian-vector
product :math:\\frac{df}{dx_{arg}}(x1,...,xn)\cdot dx .  More conveniently,
you can just provide a list of functions Ji(x1,...,xn), with i=1,..,n,
with Ji giving :math:\\frac{df}{dx_i}(x1,...,xn) , or in the JVP form,
functions Jdxi(dx,x1,...,xn) each giving
:math:\\frac{df}{dx_{i}}(x1,...,xn)\cdot dx.  See the :func:function
documentation for more details.

Example::

def my_func(x,y):
return x**2 - y**2

derivative=[lambda x,y:np.array(2*x).resize((1,1)),
lambda x,y:np.array(-2*y).resize((1,1))])  #gives Jacobians
jvp=[lambda dx,x,y:2*x*dx, lambda dy,x,y:-2*y*dy])  #gives Jacobian-vector products

#now you can do a bunch of things...
print(my_func_ad4(var('x'),var('x')).derivative('x',x=2))  #evaluates the derivative, which is always [[0]]
print(my_func_ad4(var('x'),2.0).derivative('x',x=2))  #evaluates the derivative, which is 4
print(my_func_ad(var('x'),2.0).derivative('x',x=2))   #uses finite differencing to approximate the derivative

For more sophisticated functions, such as those that need to be constructed
with some custom user data, you may subclass the :class:ADFunctionInterface
interface class. At the very least, you should overload the eval method,
although you will likely also want to overload n_args, n_in, n_out, and
one of the derivative methods.

The same example as above::

def __str__(self):
return 'my_func'
def n_args(self):
return 2
def n_in(self,arg):
return 1
def n_out(self):
return 1
def eval(self,x,y):
return x**2 - y**2
def jvp(self,arg,darg,x,y):
if arg == 0:  #x
return 2*x*darg
else: #y
return -2*y*darg

print(my_func_ad4(var('x'),var('x')).derivative('x',x=2))  #evaluates the derivative, which is always [[0]]
print(my_func_ad4(var('x'),2.0).derivative('x',x=2))  #evaluates the derivative, which is 4

If you are going to be taking a lot of derivatives w.r.t. vector variables,
then the most efficient way to implement your functions is to provide the
derivative method.  Otherwise, the most efficient way is to provide the
jvp method.  The reason is that taking the derivative w.r.t. a vector will
create intermediate matrices, and to perform the chain rule, derivative
will perform a single call and a matrix multiply.  On the other hand, jvp
will be called once for each column in the matrix.  If you are taking
derivatives w.r.t. scalars, jvp won't create the intermediate Jacobian
matrix, so it will be faster.  Implementing both will let the system figure out
the best alternative.

When you create a new function with its derivatives, you ought to run it
through check_derivatives to see whether they match with finite differences.
To modify the finite difference step size, modify FINITE_DIFFERENCE_RES.

=========

This module defines :func:var, :func:function, and the helper functions
:func:check_derivatives, :func:eval_multiple, :func:derivative_multiple
and :func:finite_differences.

The operators :data:stack, :data:sum_, :data:minimum, and :data:maximum
operate somewhat like their Numpy counterparts:

- stack acts like np.stack, since all items are 1D arrays.
- sum\_ acts like Numpy's sum if only one item is provided.  Otherwise, it
acts like the builtin sum.
- minimum/maximum act like ndarray.min/max if only one item is provided, and
otherwise they act like Numpy's minimum/maximum. If more than 2 items are
provided, they are applied elementwise and sequentially.

The operators add, sub, mul, div, neg, pow\_, abs\_, and getitem are also
provided here, but you will probably just want to use the standard Python
operators +, -, *, /, -, **, abs, and [].  On arrays, they act just like
the standard numpy functions.

Another convenience function is :func:setitem, which has the signature
setitem(x,indices,v) and is equivalent to::

xcopy = np.copy(x)
xcopy[indices] = v
return xcopy

which is sort of a proper functional version of the normal x[indices]=v.

The function :data:cond acts like an if statement, where
cond(pred,posval,negval) returns posval if pred > 0 and negval otherwise.
There's also cond3(pred,posval,zeroval,negval) that returns posval if
pred > 0, zeroval if pred==0, and negval otherwise.  Conditions can also be
applied when pred is an array, but it must be of equal size to posval and
negval.  The result in this case is an array with the i'th value taking on
posval[i] when pred[i] > 0 and negval[i] otherwise.

=======================

autodiff versions of many Klamp't functions are found in:

- :mod:math_ad : basic vector math, including functions from the
:mod:~klampt.math.vectorops module.
- :mod:so3_ad : operations on SO(3), mostly compatible with the
:mod:~klampt.math.so3 module.
- :mod:se3_ad : operations on SE(3), mostly compatible with the
:mod:~klampt.math.se3 module.
- :mod:kinematics_ad kinematics functions, e.g., robot Jacobians.
- :mod:geometry_ad geometry derivative functions.
- :mod:dynamics_ad dynamics derivative functions.
- :mod:trajectory_ad trajectory derivative functions.

"""

import numpy as np

FINITE_DIFFERENCE_RES = 1e-6
"""The resolution used for finite differences."""

"""The base class for a new auto-differentiable function.

The function has the form :math:f(x_0,...,x_k) with k+1 ==
self.n_args().  All inputs and outputs are either scalars or numpy
ndarrays.

To define a new function, you need to implement n_args(),
eval(*args), and optionally n_in(arg), n_out(),
derivative(arg,*args), and jvp(arg,darg,*args).

A function class can be instantiated with non-numeric data.  For example,
to perform forward kinematics the call::

(see :mod:kinematics_ad) creates a function wp(q) that will accept a
robot configuration q and return the world position of the point
localPos on link link.

This makes these function classes more expressive, but also leads to a
slight annoyance in that to define a function that can be called, you need
to instantiate a function.  For exaple, the sin function is defined in
:mod:math_ad as sin = ADSinFunction() where ADSinFunction is
the actual class that implements the sin function.  This construction lets
you call math_ad.sin(var('x')) as though it were a normal function.

You can also call the convenience routine
fn_name = function(func,'fn_name',argspec,outspec) to do
just this.  See the documentation of :func:function for more inf.

The general convention is that CapitalizedFunctions are function classes
that must be instantiated, while lowercase_functions are already
instantiated.
"""

def __str__(self):
"""Returns a descriptive string for this function"""
return self.__class__.__name__

[docs]    def argname(self,arg):
"""Returns a descriptive string for argument #arg"""
return "Arg %d"%(arg,)

[docs]    def n_args(self):
"""Returns the number of arguments."""
raise NotImplementedError()

[docs]    def n_in(self,arg):
"""Returns the number of entries in argument #arg.  If 1, this can
either be a 1-D vector or a scalar.  If -1, the function can accept a
variable sized argument.
"""
raise NotImplementedError()

[docs]    def n_out(self):
"""Returns the number of entries in the output of the function. If -1,
this can output a variable sized argument.
"""
raise NotImplementedError()

[docs]    def eval(self,*args):
"""Evaluates the application of the function to the given
(instantiated) arguments.

Args:
args (list): a list of arguments, which are either ndarrays or
scalars.
"""
raise NotImplementedError()

[docs]    def derivative(self,arg,*args):
"""Returns the Jacobian of the function w.r.t. argument #arg.

Args:
arg (int): A value from 0,...,self.n_args()-1 indicating that we
wish to take :math:df/dx_{arg}.
args (list of ndarrays): arguments for the function.

Returns:
ndarray: A numpy array of shape (self.n_out(),self.n_in(arg)).
Keep in mind that even if the argument or result is a scalar, this
needs to be a 2D array.

If the derivative is not implemented, raise a NotImplementedError.

If the derivative is zero, can just return 0 (the integer)
regardless of the size of the result.
"""
raise NotImplementedError()

[docs]    def gen_derivative(self,arg,*args):
"""Generalized derivative that allows higher-order derivatives to be
taken.

Args:
arg (list): Indicates the order of derivatives to be taken.  For
example, to take the 2nd derivative w.r.t. x0,x1,
arg = [0,1] should be specified.
args (list of ndarrays): arguments for the function.

Returns:
ndarray: A tensor of shape (n_out,n_in(arg[0]),...,
n_in(arg[-1]).

If the generalized derivative is not implemented, raise a
NotImplementedError.

If the generalized derivative is zero, can just return 0 (the
integer) regardless of the size of the result.
"""
raise NotImplementedError()

[docs]    def jvp(self,arg,darg,*args):
"""Performs a Jacobian-vector product for argument #arg.

Args:
arg (int): A value from 0,...,self.n_args()-1 indicating that we
wish to calculate df/dx_arg * darg.
darg (ndarray): the derivative of x_arg w.r.t some other parameter.
Must have darg.shape = (self.n_in(arg),).
args (list of ndarrays): arguments for the function.

Returns:
ndarray: A numpy array of length self.n_out()

If the derivative is not implemented, raise a NotImplementedError.
"""
raise NotImplementedError()

def __call__(self,*args):
try:
na = self.n_args()
if na >= 0 and len(args) != na:
raise ValueError("Invalid number of arguments to function %s, wanted %d, got %d"%(str(self),na,len(args)))
except NotImplementedError:
pass

"""A symbolic value that will be bound later.

Consider using :func:var rather than calling this explicitly.
"""
def __init__(self,name):
self.name = name
def __str__(self):
return self.name
def __neg__(self):
return neg(self)
def __sub__(self,rhs):
return sub(self,rhs)
def __mul__(self,rhs):
return mul(self,rhs)
def __div__(self,rhs):
return div(self,rhs)
def __pow__(self,rhs):
return pow_(self,rhs)
def __truediv__(self,rhs):
return div(self,rhs)
def __pow__(self,rhs):
return pow_(self,rhs)
def __rsub__(self,lhs):
return sub(lhs,self)
def __rmul__(self,lhs):
return mul(lhs,self)
def __rdiv__(self,lhs):
return div(lhs,self)
def __rtruediv__(self,lhs):
return div(lhs,self)
def __rpow__(self,lhs):
return pow_(lhs,self)
def __getitem__(self,index):
return getitem(self,index)
def __abs__(self):
return abs_(self)

"""A function call in an expression tree."""
def __init__(self,func,args):
self.func = func
self.args = [a for a in args]
for i,arg in enumerate(args):
if isinstance(arg,str):
pass
try:
narg = arg.func.n_out()
nin = func.n_in(i)
if nin >= 0 and narg >= 0 and nin != narg:
raise ValueError("Invalid size of argument %s=%s to function %s: %d != %d"%(func.argname(i),str(arg),str(func),narg,nin))
except NotImplementedError:
#argument sizes not specified, ignore
pass
raise ValueError("Can't pass a ADFunctionInterface directly as an argument")
else:
#a constant
try:
nin = func.n_in(i)
if nin >= 0:
if _size(arg) != nin:
raise ValueError("Function %s expected %s to have size %d, instead got %d",self.func.name,func.argname(i),nin,_size(arg))
except NotImplementedError:
#argument sizes not specified, ignore
pass
pass
self._cache = {}

def __str__(self):
terms = {}
dag = self._print_dag(terms)
self._clear_cache('print_id')
def _flatten_str_tree(tree):
flat = []
for v in tree:
if isinstance(v,list):
flat.append(_flatten_str_tree(v))
else:
flat.append(v)
return ''.join(flat)
return _flatten_str_tree(dag)

[docs]    def eval(self,**kwargs):
"""Evaluate a function call with bound values given by the keyword
arguments.  Example::

print((3*var('x')**2 - 2*var('x')).eval(x=1.5))

"""
self._clear_cache('eval_result','eval_context','eval_instantiated_args')
return self._eval([],kwargs)

[docs]    def derivative(self,arg,**kwargs):
"""Evaluates the derivative of this computation graph with respect to
the named variable at the given setings in kwargs.

If any derivatives are not defined, finite differences will be used
with step size FINITE_DIFFERENCE_RES.

If the prior eval() call wasn't with the same arguments, then this will
call eval() again with kwargs.  So, the sequence eval(**kwargs),
derivative(**kwargs) will save some time compared to changing the
args.
"""
if 'eval_context' not in self._cache or not _context_equal(self._cache['eval_context'],kwargs):
self.eval(**kwargs)
assert _context_equal(self._cache['eval_context'],kwargs)
self._clear_cache('deriv_result')
res = self._deriv([],arg,kwargs)
if res is 0:
return np.zeros((_size(self._cache['eval_result']),_size(kwargs[arg])))
if len(res.shape) == 1:
#deriv w.r.t. scalar
res = res.reshape((res.shape[0],1))
return res

[docs]    def gen_derivative(self,args,**kwargs):
"""Evaluates the generalized derivative of this computation graph with
respect to the named variables in args, at the given settings in
kwargs.

Examples:

- (var('x')**3).gen_derivative(['x','x']) returns the
second derivative of :math:x^3 with respect to x.
- var('x')**2*var('y').gen_derivative(['x','y']) returns
:math:\\frac{d^2}{dx dy} x^2 y.

If args is [], this is equivalent to self.eval(**kwargs).

If args has only one element, this is equivalent to
self.derivative(args[0],**kwargs).

Args:
args (list of str): the denominators of the derivative to be
evaluated.
kwargs (dict): the values at which the derivative is evaluated.

Result:
ndarray: A tensor of shape
(_size(this),_size(args[0]),...,_size(args[-1])) containing
the derivatives.
"""
if len(args)==0:
return self.eval(**kwargs)
elif len(args)==1:
return self.derivative(args[0],**kwargs)
elif len(args)==2:
if 'eval_context' not in self._cache or not _context_equal(self._cache['eval_context'],kwargs):
self.eval(**kwargs)
assert _context_equal(self._cache['eval_context'],kwargs)
self._clear_cache('hessian_result','deriv_darg1','deriv_darg2')
self.derivative(args[0],**kwargs)
self._move_cache('deriv_result','deriv_darg1')
self.derivative(args[1],**kwargs)
self._move_cache('deriv_result','deriv_darg2')
res = self._hessian([],args[0],args[1],kwargs)
hess_shape = (_size(self._cache['eval_result']),)+tuple(_size(kwargs[arg]) for arg in args)
if res is 0:
return np.zeros(hess_shape)
return res
else:
raise NotImplementedError("TODO: higher-order derivatives")

[docs]    def replace(self,**kwargs):
"""Replaces any symbols whose names are keys in kwargs with the
matching value as an expression.  The result is an ADFunctionCall.

For example, (var('x')**2 + var('y')).replace(x=var('z'),y=3) would
return the expression var('z')**2 + 3.
"""
res = self._replace(kwargs)
self._clear_cache('replace_result')
return res

[docs]    def terminals(self,sort=False):
"""Returns the list of terminals in this expression.  If sort=True,
they are sorted by alphabetical order.

Returns:
list of str
"""
res = []
resset = set()
self._terminals(res,resset)
self._clear_cache('terminals_visited')
if sort:
return sorted(res)
return res

def _print_dag(self,terms):
assert 'print_id' not in self._cache
id = len(terms)
self._cache['print_id'] = id
argstr = []
for v in self.args:
argstr.append(str(v))
if 'print_id' in v._cache:
vid = v._cache['print_id']
argstr.append('@'+str(vid))
if not terms[vid][-1].startswith('#'):
terms[vid].append('#'+str(vid))
else:
argstr.append(v._print_dag(terms))
else:
argstr.append(str(v))
argstr.append(',')
argstr.pop(-1)
res = [str(self.func),'(',argstr,')']
terms[id] = res
return res

def _eval(self,callStack,kwargs):
if 'eval_result' in self._cache:
return self._cache['eval_result']
self._cache['eval_context'] = kwargs
instantiated_args = []
callStack.append(self)
for i,arg in enumerate(self.args):
if arg.name not in kwargs:
raise ValueError("Function argument "+arg.name+" not present in arguments")
instantiated_args.append(kwargs[arg.name])
if np.ndim(instantiated_args[-1]) > 1:
raise ValueError("Terminal",arg.name,"is an array with more than 1D? Can only support 1D arrays")
instantiated_args.append(arg._eval(callStack,kwargs))
if np.ndim(instantiated_args[-1]) > 1:
raise ValueError("Function",arg.func.name,"returned array with more than 1D? Can only support 1D arrays")
else:
instantiated_args.append(arg)
try:
ni = self.func.n_in(i)
if ni >= 0:
if _size(instantiated_args[-1]) != ni:
raise ValueError("Error evaluating %s: argument %d has expected size %d, instead got %d"%(str(self),i,ni,_size(instantiated_args[-1])))
except NotImplementedError:
#not implemented, just ignore
pass

callStack.pop(-1)
self._cache['eval_instantiated_args'] = instantiated_args
try:
res = self.func.eval(*instantiated_args)
except Exception:
context = [str(s) for s in callStack]
raise RuntimeError("Error running function %s in context %s"%(str(self.func),' -> '.join(context)))
try:
no = self.func.n_out()
if not isinstance(no,int):
raise ValueError("Error evaluating %s: return # of outputs is not integer? got %s"%(str(self),str(no)))
if no >= 0:
if _size(res) != no:
raise ValueError("Error evaluating %s: return value has expected size %d, instead got %d"%(str(self),no,_size(res)))
except NotImplementedError:
#not implemented, just ignore
pass
self._cache['eval_result'] = res
return res

def _deriv_jacobian(self,arg,args):
n = _size(args[arg])
no = -1
try:
no = self.func.n_out()
except NotImplementedError:
pass
try:
df = self.func.derivative(arg,*args)
if len(df.shape) != 2:
raise ValueError("Function %s.derivative doesn't return a 2D array?  Context is %s, result has shape %s"%(str(self.func),str(self),df.shape))
if no >= 0 and df.shape != (no,n):
raise ValueError("Function %s.derivative has wrong size?  Context is %s, expects %s, result has shape %s"%(str(self.func),str(self),(no,n),df.shape))
return df
except NotImplementedError:
pass
try:
df = 0
temp = np.zeros(n)
for col in range(n):
temp[col] = 1
dcol = self.func.jvp(arg,temp,*args)
if df is 0:
df = np.empty((_size(dcol),n))
if no >= 0 and _size(dcol) != no:
raise ValueError("Function %s.jvp result has wrong size? Context is %s, expects %d, result has size %d"%(str(self.func),str(self),no,_size(dcol)))
df[:,col] = dcol
temp[col] = 0
return df
except NotImplementedError:
pass
return finite_differences(lambda x:self.func.eval(*(args[:arg]+[x]+args[arg+1:])),args[arg],FINITE_DIFFERENCE_RES)

def _deriv_jacobian_array_product(self,arg,mat,args):
n = _size(args[arg])
no = -1
try:
no = self.func.n_out()
except NotImplementedError:
pass
if len(mat.shape) == 1 or mat.shape[1] == 1:
#try 1-D jvp first, then derivative, then finite difference
if len(mat.shape) == 2:
mat = mat[:,0]
if mat.shape[0] != n:
raise ValueError("1D derivative of arg %s of function %s expected size %d, got %d. Context is %s"%(self.func.argname(arg),str(self.func),n,mat.shape[0],str(self)))
try:
res = self.func.jvp(arg,mat,*args)
if _scalar(res):
#for jvps, if the result is a scalar this needs to be resized to a 1d array
return np.asarray(res).reshape((1,))
if len(res.shape) != 1:
raise ValueError("Function %s.jvp doesn't return a 1D array or scalar?  Context is %s, result has shape %s"%(str(self.func),str(self),res.shape))
if no >= 0 and res.shape[0] != no:
raise ValueError("Function %s.jvp result has wrong size? Context is %s, expects %d, result has size %d"%(str(self.func),str(self),no,res.shape[0]))
return res
except NotImplementedError:
pass
try:
df = self.func.derivative(arg,*args)
if len(df.shape) != 2:
raise ValueError("Function %s.derivative doesn't return a 2D array?  Context is %s, result has shape %s"%(str(self.func),str(self),df.shape))
if no >= 0 and df.shape != (no,n):
raise ValueError("Function %s.derivative has wrong size?  Context is %s, expects %s, result has shape %s"%(str(self.func),str(self),(no,n),df.shape))
return np.dot(df,mat)
except NotImplementedError:
pass
return finite_differences(lambda t:self.func.eval(*(args[:arg]+[args[arg]+t*mat]+args[arg+1:])),0,FINITE_DIFFERENCE_RES)
else:
#try derivative * mat first, then directional jvp, then fd * mat
try:
df = self.func.derivative(arg,*args)
if len(df.shape) != 2:
raise ValueError("Function %s.derivative doesn't return a 2D array?  Context is %s, result has shape %s"%(str(self.func),str(self),df.shape))
return np.dot(df,mat)
except NotImplementedError:
pass
try:
columns = []
if mat.shape[0] != n:
raise ValueError("Derivative of arg %s of function %s expected size %d, got %d. Context is %s"%(self.func.argname(arg),str(self.func),n,mat.shape[0],str(self)))
for col in range(mat.shape[1]):
columns.append(self.func.jvp(arg,mat[:,col],*args))
if no >= 0 and _size(columns[-1]) != no:
raise ValueError("Function %s.jvp result has wrong size? Context is %s, expects %d, result has size %d"%(str(self.func),str(self),no,_size(columns[-1])))
return np.column_stack(columns)
except NotImplementedError:
pass
return np.dot(finite_differences(lambda x:self.func.eval(*(args[:arg]+[x]+args[arg+1:])),args[arg],FINITE_DIFFERENCE_RES),mat)

def _deriv(self,callStack,deriv_arg,kwargs):
if 'deriv_result' in self._cache:
return self._cache['deriv_result']
assert 'eval_context' in self._cache
eval_result = self._cache['eval_result']
eval_args = self._cache['eval_instantiated_args']
callStack.append(self)
res = 0
for i,arg in enumerate(self.args):
if arg.name == deriv_arg:
inc = self._deriv_jacobian(i,eval_args)
assert hasattr(inc,'shape') and inc.shape[0] == _size(eval_result),"Function %s has invalid jacobian shape? Context %s"%(str(self.func),str(self))
if _scalar(eval_args[i]):
#for scalar derivatives, we just use a 1D vector rather than a 2D array
inc = inc[:,0]
else:
#not deriv_arg, no derivative
continue
da = arg._deriv(callStack,deriv_arg,kwargs)
if da is 0:
continue
inc = self._deriv_jacobian_array_product(i,da,eval_args)
assert hasattr(inc,'shape') and inc.shape[0] == _size(eval_result),"Function %s has invalid jacobian shape? Context %s"%(str(self.func),str(self))
else:
#const, no derivative
continue
if res is 0:
res = inc
else:
res += inc
callStack.pop(-1)
self._cache['deriv_result'] = res
return res

def _hessian(self,callStack,darg1,darg2,kwargs):
#d/dv f(x1,...,xn) = sum_{i=1...,n} d/dxi f(...)* dxi/dv
#d/du d/dv f(x1,...,xn) = sum_{i=1...,n} (sum_{j=1...,n} (d^2/dxixj) f(...)* dxi/dv * dxj/du) + d/dxi f(...) (d^2 / duv) xi
if 'hessian_result' in self._cache:
return self._cache['hessian_result']
assert 'eval_context' in self._cache
assert 'deriv_darg1' in self._cache
assert 'deriv_darg2' in self._cache
eval_result = self._cache['eval_result']
eval_args = self._cache['eval_instantiated_args']
callStack.append(self)
res = 0
#add d/dxi f * (d^2/duv) xi to res
for i,arg in enumerate(self.args):
#no second derivative of a constant
continue
da = arg._hessian(callStack,darg1,darg2,kwargs)
if da is 0:
continue
inc = np.empty((_size(eval_result),_size(kwargs[darg1]),_size(kwargs[darg2])))
for i in range(da.shape[-1]):
inc[:,:,i] = self._deriv_jacobian_array_product(i,da[:,:,i],eval_args)
print("Deriv * hessian",arg," = ",inc)
else:
#const, no derivative
continue
if res is 0:
res = inc
else:
res += inc
#add (d^2/dxixj) f(...)* dxi/dv * dxj/du
for i,arg in enumerate(self.args):
if arg.name != darg1:
continue
if arg._cache['deriv_arg1'] is 0:
continue
else:
#constant
continue
for j,arg2 in enumerate(self.args):
inc = 0
if arg2.name != darg2:
continue
if arg2._cache['deriv_arg2'] is 0:
continue
else:
#constant
continue
try:
ddf = self.func.gen_derivative([i,j],*eval_args)
except NotImplementedError:
return finite_differences()
#print("d^2/d",i,j,"of func",self.func,"is",ddf)
if ddf is 0:
continue
#print("d^2/d",i,j,"of terminal",arg.name,"and terminal",arg2.name)
inc = ddf
#print("d^2/d",darg1,darg2," of terminal",arg.name,"and arg",j)
inc = np.tensordot(ddf,(2,0),arg2._cache['deriv_arg2'])
#print("d^2/d",i,j," of arg",i,"and terminal",arg2.name)
inc = np.tensordot(ddf,(1,0),arg._cache['deriv_arg1'])
#print("d^2/d",darg1,darg2," of arg",i," and arg",j)
inc = np.tensordot(np.tensordot(ddf,(1,0),arg._cache['deriv_arg1']),(1,0),arg._cache['deriv_arg2'])
print("hessian",arg,"of func",self.func,"*d^2/dx",i,"dx",j," = ",inc)
if res is 0:
res = inc
else:
res += inc
callStack.pop(-1)
self._cache['hessian_result'] = res
return res

def _replace(self,kwargs):
if 'replace_result' in self._cache:
return self._cache['replace_result']
replaced = False
newargs = []
for arg in self.args:
if arg.name in kwargs:
newargs.append(kwargs[arg.name])
replaced = True
else:
newargs.append(arg)
newargs.append(arg._replace(kwargs))
if newargs[-1] is not arg:
replaced = True
else:
newargs.append(arg)
if replaced:
else:
#circular reference... can we make sure this won't mess up the GC?
res = self
self._cache['replace_result'] = res
return res

def _terminals(self,res,resset):
if 'terminals_visited' in self._cache:
return
for arg in self.args:
if arg.name not in resset:
res.append(arg.name)
arg._terminals(res,resset)
self._cache['terminals_visited'] = True
return

def _clear_cache(self,*args):
if args[0] not in self._cache:
return
for a in args:
del self._cache[a]
for v in self.args:
v._clear_cache(*args)

def _move_cache(self,src,dest):
if src not in self._cache:
return
self._cache[dest] = self._cache[src]
del self._cache[src]
for v in self.args:
v._move_cache(src,dest)

def __neg__(self):
return neg(self)
def __sub__(self,rhs):
return sub(self,rhs)
def __mul__(self,rhs):
return mul(self,rhs)
def __div__(self,rhs):
return div(self,rhs)
def __pow__(self,rhs):
return pow_(self,rhs)
def __truediv__(self,rhs):
return div(self,rhs)
def __pow__(self,rhs):
return pow_(self,rhs)
def __rsub__(self,lhs):
return sub(lhs,self)
def __rmul__(self,lhs):
return mul(lhs,self)
def __rdiv__(self,lhs):
return div(lhs,self)
def __rtruediv__(self,lhs):
return div(lhs,self)
def __rpow__(self,lhs):
return pow_(lhs,self)
def __getitem__(self,index):
return getitem(self,index)
def __abs__(self):
return abs_(self)

[docs]def var(x):
"""Creates a new variable symbol with name x."""

[docs]def function(func,name='auto',argspec='auto',outspec=None,argnames=None,
derivative=None,jvp=None,gen_derivative=None,order=None):
"""Declares and instantiates a ADFunctionInterface object that calls the
Python function func.

Args:
func (callable): a Python function.
name (str, optional): the string that the ADFunctionInterface prints.
argspec (list, optional): a list of argument sizes, or 'auto' to auto-
set the argument list from the function signature.
outspec (int, optional): a size of the output.
argnames
derivative (callable or list of callables, optional): one or more
derivative functions.  If a callable, takes the index-arguments
tuple (arg,*args).  If a list, each element just takes the same
arguments as func.
jvp (callable or list of callables, optional): one or more Jacobian-
vector product functions. If a single callable, has signature
(arg,darg,*args).  If a list, each element has the signature
(darg,*args).
gen_derivative (callable, list of callables, or dict, optional): a
higher-order derivative function.

If a callable is given, it has standard signature (arg,*args).

If a list of callables is given, they are a list of higher-order
derivatives, starting at the second derivative.  In this case,
the function may only have 1 argument, and each callable has the
signature (*args).

If a dict is given, it must map tuples to callables, each of which
has the signature (*args).
order (int, optional): if provided, we know that all derivatives higher
than this order will be zero.
"""
if name == 'auto':
import inspect
for x in inspect.getmembers(func):
if x[0] == '__name__':
name = x[1]
break
assert name != None,"Could not get the name of the function?"
if argspec == 'auto':
import inspect
(argnames,varargs,keywords,defaults) = inspect.getargspec(func)
if varargs != None or keywords != None:
#variable number of arguments
raise ValueError("Can't handle variable # of arguments yet")
else:
argspec = [-1]*len(argnames)
if argnames is not None:
if len(argnames) != len(argspec):
raise ValueError("Invalid number of argument names provided")
def run_if_not_none(f,*args):
if f is None:
raise NotImplementedError()
else:
return f(*args)
if hasattr(derivative,'__iter__'):
derivative_list = derivative
derivative = lambda arg,*args:run_if_not_none(derivative_list[arg],*args)
if hasattr(jvp,'__iter__'):
jvp_list = jvp
jvp = lambda arg,darg,*args:run_if_not_none(jvp_list[arg],darg,*args)
if gen_derivative is None and order is not None:
def default_gen_deriv(arg,*args):
if len(arg) > order:
return 0
if len(arg) == 1:
return derivative(arg[0],*args)
raise NotImplementedError()
gen_derivative = default_gen_deriv
if isinstance(gen_derivative,dict):
gen_derivative_dict = gen_derivative
def dict_gen_deriv(arg,*args):
if len(arg) == 1:
return derivative(arg[0],*args)
if tuple(arg) in gen_derivative_dict:
return gen_derivative_dict(tuple(arg),*args)
if order is not None and len(tuple) > order:
return 0
raise NotImplementedError()
gen_derivative = dict_gen_deriv
elif hasattr(gen_derivative,'__iter__'):
if len(argspec) > 1:
raise ValueError("Cannot accept a list of higher-order derivatives except for a single-argument function")
gen_derivative_list = gen_derivative
def list_gen_deriv(arg,*args):
if len(arg) == 1:
return derivative(arg[0],*args)
if len(arg)-2 < len(gen_derivative_list):
if gen_derivative_list[len(arg)-1] is None:
raise NotImplementedError()
return gen_derivative_list[len(arg)-1](*args)
if order is not None and len(tuple) > order:
return 0
raise NotImplementedError()
members = {}
members['__str__'] = lambda self:name
members['eval'] = lambda self,*args:np.asarray(func(*args))
members['n_args'] = lambda self:len(argspec)
members['n_in'] = lambda self,arg:argspec[arg]
if outspec is not None:
members['n_out'] = lambda self:outspec
def toarray(res):
return res if _scalar(res) else np.asarray(res)
if derivative is not None:
members['derivative'] = lambda self,arg,*args:toarray(derivative(arg,*args))
if jvp is not None:
members['jvp'] = lambda self,arg,darg,*args:toarray(jvp(arg,darg,*args))
if gen_derivative is not None:
members['gen_derivative'] = lambda self,arg,*args:toarray(gen_derivative(arg,*args))
if argnames is not None:
members['argname'] = lambda self,arg:argnames[arg]
return TempType()

[docs]def finite_differences(f,x,h=1e-6):
"""Performs forward differences to approximate the derivative of f at x.
f can be a vector or scalar function, and x can be a scalar or vector.

The result is a Numpy array of shape (_size(f(x)),_size(x)).

h is the step size.
"""
f0 = f(x)
g = np.empty((_size(x),_size(f0)))
if hasattr(x,'__iter__'):
temp = x.astype(float,copy=True)        #really frustrating bug: if x is an integral type, then adding h doesn't do anything
for i in range(len(x)):
temp[i] += h
g[i] = f(temp)-f0
temp[i] = x[i]
else:
g[0,:] = f(x+h) - f0
g *= 1.0/h
return g.T

[docs]def finite_differences_hessian(f,x,h):
"""Performs forward differences to approximate the Hessian of f(x) w.r.t.
x.
"""
f0 = f(x)
g = np.empty((_size(f0),_size(x),_size(x)))
if hasattr(x,'__iter__'):
xtemp = x.astype(float,copy=True)
fx = []
for i in range(len(x)):
xtemp[i] += h
fx.append(np.copy(f(xtemp)))
xtemp[i] = x[i]
for i in range(len(x)):
xtemp[i] -= h
g[:,i,i] = (fx[i] - 2*f0 + f(xtemp))
xtemp[i] = x[i]
xtemp[i] += h
for j in range(i):
xtemp[j] += h
g[:,j,i] = g[:,i,j] = (f(xtemp) - fx[i]) - (fx[j] - f0)
xtemp[j] = x[j]
xtemp[i] = x[i]
else:
fx = f(x+h)
g[:,0,0] = (fx - 2*f0 + f(x-h))
return g/h**2

[docs]def finite_differences_hessian2(f,x,y,hx,hy=None):
"""Performs forward differences to approximate the Hessian of f(x,y) w.r.t.
x and y.
"""
if hy is None:
hy = hx
f0 = f(x,y)
g = np.empty((_size(f0),_size(x),_size(y)))
if hasattr(x,'__iter__'):
xtemp = x.astype(float,copy=True)
fx = []
for i in range(len(x)):
xtemp[i] += hx
fx.append(np.copy(f(xtemp,y)))
xtemp[i] = x[i]

else:
fx = f(x+hx,y)

if hasattr(y,'__iter__'):
fy = []
ytemp = y.astype(float,copy=True)
for j in range(len(y)):
ytemp[j] += hy
fy.append(np.copy(f(x,ytemp)))
ytemp[j] = y[j]
else:
fy = f(x,y+hy)

if hasattr(x,'__iter__'):
if hasattr(y,'__iter__'):
for i in range(len(x)):
xtemp[i] += hx
for j in range(len(y)):
ytemp[j] += hy
g[:,i,j] = (f(xtemp,ytemp)-fx[i]) - (fy[j] - f0)
ytemp[j] = y[j]
xtemp[i] = x[i]
else:
for i in range(len(x)):
xtemp[i] += hx
g[:,i,0] = (f(xtemp,y+hy)-fx[i]) - (fy - f0)
xtemp[i] = x[i]
else:
if hasattr(y,'__iter__'):
for j in range(len(y)):
ytemp[j] += hy
g[:,0,j] = (f(x+hy,ytemp)-fx) - (fy[j] - f0)
ytemp[j] = y[j]
else:
g[:,0,0] = (f(x+hx,y+hy)-fx) - (fy - f0)
g *= 1.0/(hx*hy)
return g

[docs]def check_derivatives(f,x,h=1e-6,rtol=1e-2,atol=1e-3):

If f is an ADFunctionCall, the derivatives of all terminal values are
checked, and x must be a dict matching terminal values to their settings.
Example::

f = var('x')+var('y')**2+10
check_derivatives(f,{'x':3,'y':2})

If f is an ADFunctionInterface, x is list of the function's arguments.
All derivative and jvp functions are tested.

The derivatives must match the finite differences result by relative
tolerance rtol and absolute tolerance atol (see np.allclose for details).
Otherwise, an AssertionError is raised.
"""
#check w.r.t. all arguments
if not isinstance(x,dict):
raise ValueError("If a function call is provided to check_derivatives, then x must be a dict of keyword arguments")
if len(x)==0:
#no derivatives
return True
def replace_eval(y,arg):
oldx = x[arg]
x[arg] = y
res = f.eval(**x)
x[arg] = oldx
return res
for k,v in x.items():
try:
g = f.derivative(k,**x)
except NotImplementedError:
continue
g_fd = finite_differences(lambda y:replace_eval(y,k),v,h)
if g is 0:
g = np.zeros(g_fd.shape)
elif g.shape != g_fd.shape:
raise AssertionError("derivative of %s w.r.t. %s did not produce the correct shape: %s vs %s"%(str(f),k,g.shape,g_fd.shape))
if not np.allclose(g_fd,g,rtol,atol):
print("check_derivative",f,"failed with args",x,"@ argument",k)
print("Finite differences:",g_fd)
print("Derivative:",g)
raise AssertionError("derivative of %s w.r.t. %s has an error of size %f"%(str(f.func),k,np.linalg.norm(g-g_fd)))
assert f.n_args() < 0 or f.n_args() == len(x),"Invalid number of arguments provided to %s: %d != %d"%(str(f),len(x),f.n_args())
has_gen_derivative = True
for i,v in enumerate(x):
has_derivative = False
try:
g = f.derivative(i,*x)
has_derivative = True
g_fd = finite_differences(lambda y:f.eval(*(x[:i]+[y]+x[i+1:])),v,h)
if g is 0:
g = np.zeros(g_fd.shape)
elif g.shape != g_fd.shape:
raise AssertionError("derivative of %s w.r.t. %s did not produce the correct shape: %s vs %s"%(str(f),f.argname(i),g.shape,g_fd.shape))
if not np.allclose(g_fd,g,rtol,atol):
print("check_derivative",f,"failed with args",x,"@ argument",i)
print("Finite differences:",g_fd)
print("Derivative:",g)
raise AssertionError("derivative of %s w.r.t. %s has an error of size %f"%(str(f),f.argname(i),np.linalg.norm(g-g_fd)))
except NotImplementedError:
pass

has_jvp = False
try:
for j in range(_size(v)):
dv = np.zeros(_size(v))
dv[j] = 1
g = f.jvp(i,dv,*x)
has_jvp = True
if _scalar(v):
dv = dv[0]
g_fd = finite_differences(lambda t:f.eval(*(x[:i]+[v+dv*t]+x[i+1:])),0,h)[:,0]
if g is 0:
g = np.zeros(g_fd.shape)
if _scalar(g):
g = np.array([g])
elif g.shape != g_fd.shape:
raise AssertionError("Jacobian-vector product of %s w.r.t. %s did not produce the correct shape: %s vs %s"%(str(f),f.argname(i),g.shape,g_fd.shape))
if not np.allclose(g_fd,g,rtol,atol):
print("check_derivative",f,"failed with args",x,"@ argument",i)
print("Chosen direction",dv)
print("jvp Finite differences:",g_fd)
print("jvp:",g)
raise AssertionError("Jacobian-vector product of %s w.r.t. %s has an error of size %f"%(str(f),f.argname(i),np.linalg.norm(g-g_fd)))
for j in range(_size(v)):
#test with scaling
dv = np.zeros(_size(v))
dv[j] = np.random.uniform(-0.9,-0.05)
g = f.jvp(i,dv,*x)
has_jvp = True
if _scalar(v):
dv = dv[0]
g_fd = finite_differences(lambda t:f.eval(*(x[:i]+[v+dv*t]+x[i+1:])),0,h)[:,0]
if g is 0:
g = np.zeros(g_fd.shape)
if _scalar(g):
g = np.array([g])
elif g.shape != g_fd.shape:
raise AssertionError("Jacobian-vector product of %s w.r.t. %s did not produce the correct shape: %s vs %s"%(str(f),f.argname(i),g.shape,g_fd.shape))
if not np.allclose(g_fd,g,rtol,atol):
print("check_derivative",f,"failed with args",x,"@ argument",i)
print("Chosen direction",dv)
print("jvp Finite differences:",g_fd)
print("jvp:",g)
raise AssertionError("Jacobian-vector product of %s w.r.t. %s has an error of size %f"%(str(f),f.argname(i),np.linalg.norm(g-g_fd)))

dv = np.random.uniform(-1,-1,_size(v))
g = f.jvp(i,dv,*x)
has_jvp = True
if _scalar(v):
dv = dv[0]
g_fd = finite_differences(lambda t:f.eval(*(x[:i]+[v+dv*t]+x[i+1:])),0,h)[:,0]
if g is 0:
g = np.zeros(g_fd.shape)
if _scalar(g):
g = np.array([g])
elif g.shape != g_fd.shape:
raise AssertionError("Jacobian-vector product of %s w.r.t. %s did not produce the correct shape: %s vs %s"%(str(f),f.argname(i),g.shape,g_fd.shape))
if not np.allclose(g_fd,g,rtol,atol):
print("check_derivative",f,"failed with args",x,"@ argument",i)
print("Chosen direction",dv)
print("jvp Finite differences:",g_fd)
print("jvp:",g)
raise AssertionError("Jacobian-vector product of %s w.r.t. %s has an error of size %f"%(str(f),f.argname(i),np.linalg.norm(g-g_fd)))
except NotImplementedError:
pass
if not (has_derivative or has_jvp):
print("check_derivative: function",f,"has no derivative or jvp function defined for arg",i)

try:
H = f.gen_derivative([i,i],*x)
H_fd = finite_differences_hessian(lambda y:f.eval(*(x[:i]+[y]+x[i+1:])),v,h)
if H is 0:
H = np.zeros(H_fd.shape)
elif H.shape != H_fd.shape:
raise AssertionError("Hessian of %s w.r.t. %s did not produce the correct shape: %s vs %s"%(str(f),f.argname(i),H.shape,H_fd.shape))
if not np.allclose(H_fd,H,rtol,atol):
print("check_derivative",f,"failed with args",x,"@ argument",i)
print("Finite differences:",H_fd)
print("gen_derivatives Hessian:",H)
raise AssertionError("gen_derivative of %s w.r.t. %s has an error of size %f"%(str(f),f.argname(i),np.linalg.norm(H-H_fd)))
except NotImplementedError:
has_gen_derivative = False
pass
if has_gen_derivative:
for i,v in enumerate(x):
for j,w in list(enumerate(x))[i+1:]:
H = f.gen_derivative([i,j],*x)
H_fd = finite_differences_hessian2(lambda y,z:f.eval(*(x[:i]+[y]+x[i+1:j]+[z]+x[j+1:])),v,w,h)
if H is 0:
H = np.zeros(H_fd.shape)
elif H.shape != H_fd.shape:
raise AssertionError("Hessian of %s w.r.t. %s,%s did not produce the correct shape: %s vs %s"%(str(f),f.argname(i),f.argname(j),H.shape,H_fd.shape))
H2 = f.gen_derivative([j,i],*x)
H2_fd = finite_differences_hessian2(lambda z,y:f.eval(*(x[:i]+[y]+x[i+1:j]+[z]+x[j+1:])),w,v,h)
if H2 is 0:
H2 = np.zeros(H2_fd.shape)
elif H2.shape != H2_fd.shape:
raise AssertionError("Hessian of %s w.r.t. %s,%s did not produce the correct shape: %s vs %s"%(str(f),f.argname(j),f.argname(i),H.shape,H_fd.shape))
else:
return True

[docs]def eval_multiple(exprs,**kwargs):
"""Given a list of expressions, and keyword arguments that set variable
values, returns a list of the evaluations of the expressions.

This can leverage common subexpressions in exprs to speed up running times
compared to multiple eval() calls.
"""
for e in exprs:
e._clear_eval_cache()
return [e._eval([],kwargs) for e in exprs]

[docs]def derivative_multiple(exprs,arg,**kwargs):
"""Given a list of expressions, the argument to which you'd like to take
the derivatives, and keyword arguments that set variable values, returns
a list of the derivatives of the expressions.

This can leverage common subexpressions in exprs to speed up running times
compared to multiple derivative() calls.
"""
for e in exprs:
e._clear_deriv_cache()
if any(e._eval_context is None or not _context_equal(e._eval_context,kwargs) for e in exprs):
eval_multiple(exprs,**kwargs)
res = [e._deriv([],arg,kwargs) for e in exprs]
for i,der in enumerate(res):
if der is 0:
res[i] = np.zeros((_size(exprs[i]._eval_result),_size(kwargs[arg])))
if len(der.shape) == 1:
#deriv w.r.t. scalar
res[i] = der.reshape((der.shape[0],1))
return res

def _size(a):
try:
return a.shape[0]
except AttributeError:
if hasattr(a,'__iter__'):
raise ValueError("Can't pass a list as an argument. Try wrapping args with np.array.")
return 1
except IndexError:
#a numpy 1-parameter array
return 1

def _scalar(a):
return np.ndim(a)==0

def _context_equal(a,b):
for k,v in a.items():
try:
if v is not b[k] and not np.array_equal(v,b[k]):
return False
except KeyError:
return False
for k in b.keys():
if k not in a:
return False
return True

def _zero_derivative(nout,arg,*args):
if hasattr(arg,'__iter__'):
return np.zeros([nout]+[_size(args[a]) for a in arg])
return np.zeros([nout,_size(args[arg])])

def __str__(self):

def n_args(self):
return 2

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,x,y):
return x + y

def derivative(self,arg,*args):
size = max(_size(a) for a in args)
assert 0 <= arg <= 1
if _scalar(args[arg]):
return np.ones((size,1))
return np.eye(size)

def jvp(self,arg,darg,*args):
assert 0 <= arg <= 1
size = max(_size(a) for a in args)
if _scalar(args[arg]):
assert darg.shape[0] == 1
if size == 1:
return darg
else:
return np.hstack([darg]*size)
else:
assert darg.shape[0] == size
return darg

def gen_derivative(self,arg,*args):
if len(arg) == 1:
return self.derivative(arg[0],*args)
return 0

def __str__(self):
return 'sub'

def n_args(self):
return 2

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,x,y):
return x - y

def derivative(self,arg,*args):
size = max(_size(a) for a in args)
assert 0 <= arg <= 1
sign = 1 if arg == 0 else -1
if _scalar(args[arg]):
return sign*np.ones((size,1))
return sign*np.eye(size)

def jvp(self,arg,darg,*args):
assert 0 <= arg <= 1
size = max(_size(a) for a in args)
sign = 1 if arg == 0 else -1
if _scalar(args[arg]):
assert darg.shape[0] == 1
if size == 1:
return sign*darg
else:
return sign*np.hstack([darg]*size)
else:
assert darg.shape[0] == size
return sign*darg

def gen_derivative(self,arg,*args):
if len(arg) > 1:
return 0
return self.derivative(arg[0],*args)

def __str__(self):
return 'mul'

def n_args(self):
return 2

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,x,y):
return x*y

def derivative(self,arg,*args):
assert 0 <= arg <= 1
x = args[arg]
y = args[1-arg]
if _scalar(x):
res = np.asarray(y)
return res.reshape((_size(y),1))
if _scalar(y):
return np.diag([y]*_size(x))
return np.diag(np.asarray(y))

def jvp(self,arg,darg,*args):
return args[1-arg]*darg

def gen_derivative(self,arg,*args):
if len(arg) == 1:
return self.derivative(arg[0],*args)
elif len(arg) == 2:
i,j = arg
if i == j:
return 0
res = args[i]*args[j]
hess_shape = (_size(res),_size(args[i]),_size(args[j]))
res = np.zeros(hess_shape)
if _scalar(args[i]):
if _scalar(args[j]):
res[0,0,0] = 1
else:
for i in range(hess_shape[2]):
res[i,0,i] = 1
else:
if _scalar(args[j]):
for i in range(hess_shape[1]):
res[i,i,0] = 1
else:
for i in range(hess_shape[1]):
res[i,i,i] = 1
return res
return 0

def __str__(self):
return 'div'

def n_args(self):
return 2

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,x,y):
return x/y

def derivative(self,arg,x,y):
if arg == 0:
diag = 1.0/y
if _scalar(x):
return diag.reshape((_size(y),1))
if _scalar(y):
diag = [diag]*_size(x)
return np.diag(diag)
else:
diag = -x/y**2
if _scalar(y):
return diag.reshape((_size(x),1))
return np.diag(-x/y**2)

def jvp(self,arg,darg,x,y):
if arg == 0:
return darg/y
else:
return -(x*darg)/y**2

def gen_derivative(self,arg,x,y):
if len(arg) == 1:
return self.derivative(arg[0],x,y)
elif len(arg) == 2:
i,j = arg
if i == 0 and j == 0:
return 0
if i == j:
#double derivative w.r.t. y
m = max(_size(x),_size(y))
n = _size(y)
hess_shape = (m,n,n)
res = np.zeros(hess_shape)
xarr = [x]*n if _scalar(x) else x
if _scalar(y):
for i in range(m):
res[i,0,0] = 2*xarr[i]/y**3
else:
for i in range(n):
res[i,i,i] = 2*xarr[i]/y[i]**3
return res
else:
swap = (i == 1)
#double derivative d^2/dxy or d^2/dyx
m = max(_size(x),_size(y))
hess_shape = (m,_size(x),_size(y))
res = np.zeros(hess_shape)
if _scalar(y):
for i in range(hess_shape[1]):
res[i,i,0] = -1.0/y**2
elif _scalar(x):
for i in range(hess_shape[2]):
res[i,0,i] = -1.0/y[i]**2
else:
for i in range(hess_shape[2]):
res[i,i,i] = -1.0/y[i]**2
if swap:
return np.swapaxes(res,1,2)
else:
return res
raise NotImplementedError()

def __str__(self):
return 'neg'

def n_args(self):
return 1

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,x):
return -x

def derivative(self,arg,x):
size = _size(x)
return -np.eye(size)

def jvp(self,arg,dx,x):
return -dx

def gen_derivative(self,arg,x):
if len(arg) == 1:
return self.derivative(arg[0],x)
return 0

def __str__(self):
return 'sum'

def n_args(self):
return -1

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,*args):
if len(args) == 1:
return np.sum(args[0])
return sum(args)

def derivative(self,arg,*args):
if len(args) == 1:
return np.ones(_size(args[0]))[np.newaxis,:]
size = max(_size(a) for a in args)
if _scalar(args[arg]):
return np.ones((size,1))
return np.eye(size)

def jvp(self,arg,darg,*args):
if len(args) == 1:
return np.sum(darg)
size = max(_size(a) for a in args)
if _scalar(args[arg]):
if size == 1:
return darg
else:
return np.hstack([darg]*size)
else:
return darg

def gen_derivative(self,arg,*args):
if len(arg) == 1:
return self.derivative(arg[0],*args)
return 0

def __str__(self):
return 'pow'

def argname(self,arg):
return ['base','exp'][arg]

def n_args(self):
return 2

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,base,exp):
return np.power(base,exp)

def derivative(self,arg,base,exp):
size = max(_size(base),_size(exp))
if arg == 0:
diag = exp*np.power(base,exp-1)
if _scalar(base):
return diag.reshape((_size(diag),1))
return np.diag(diag)
else:
with np.errstate(divide='ignore'):
diag = np.log(base)
if _scalar(base):
if base == 0:
diag = np.array(0.0)
else:
diag[base==0] = 0
diag = diag*np.power(base,exp)
if _scalar(exp):
return diag.reshape((_size(diag),1))
return np.diag(diag)

def jvp(self,arg,darg,base,exp):
if arg == 0:
return exp*np.power(base,exp-1)*darg
else:
with np.errstate(divide='ignore'):
scale = np.log(base)
if _scalar(base):
if base == 0:
scale = 0.0
else:
scale[base==0] = 0
return scale*np.power(base,exp)*darg

def gen_derivative(self,arg,*args):
if len(arg)==1:
return self.derivative(arg[0],*args)
elif len(arg)==2:
base,exp = args
n = max(_size(base),_size(exp))
hess_shape = (n,_size(args[arg[0]]),_size(args[arg[1]]))
res = np.zeros(hess_shape)
if arg[0]==0 and arg[1]==0:
diag = exp*(exp-1)*np.power(base,exp-2)
if _scalar(base):
return diag.reshape((1,1,1))
else:
for i in range(n):
res[i,i,i] = diag[i]
elif arg[0]==1 and arg[1]==1:
#first deriv is log(base)*base**exp
#second deriv is log(base)**2*base**exp
with np.errstate(divide='ignore'):
scale = np.log(base)**2
if _scalar(base):
if base == 0:
scale = 0.0
else:
scale[base==0] = 0
diag = scale*np.power(base,exp)
if _scalar(exp):
return diag.reshape((n,1,1))
else:
for i in range(n):
res[i,i,i] = diag[i]
else:
#d/dexp log(base)*base**exp
#d^2/dexp dbase is (1 + log(base)*exp)*base**(exp-1)
with np.errstate(divide='ignore'):
scale = np.log(base)
if _scalar(base):
if base == 0:
scale = 0.0
else:
scale[base==0] = 0
diag = (1 + scale*exp)*np.power(base,exp-1)
if arg[0] == 1:
for i in range(n):
j = 0 if hess_shape[1] == 1 else i
k = 0 if hess_shape[2] == 1 else i
res[i,j,k] = diag[i]
return res
raise NotImplementedError()

def __str__(self):
return 'abs'

def n_args(self):
return 1

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,x):
return np.abs(x)

def derivative(self,arg,x):
assert arg == 0
size = _size(x)
if _scalar(x):
return np.sign(x).reshape((1,1))
return np.diag(np.sign(x))

def jvp(self,arg,dx,x):
return dx*np.sign(x)

def gen_derivative(self,arg,x):
if len(arg) == 1:
return self.derivative(arg[0],x)
return 0

def __init__(self,indices):
self.indices = indices

def __str__(self):
return 'getitem['+str(self.indices)+']'

def n_args(self):
return 1

def n_in(self,arg):
return -1

def n_out(self,argval=None):
if argval is None:
return -1
if isinstance(self.indices,slice):
start, stop, step = self.indices.indices(len(argval))
return (stop-start)//step
elif hasattr(self.indices,'__iter__'):
return len(self.indices)
else:
return 1

def eval(self,x):
return x[self.indices]

def derivative(self,arg,x):
assert arg == 0
size = _size(x)
if isinstance(self.indices,slice):
start, stop, step = self.indices.indices(size)
nres = (stop-start)//step
res = np.zeros((nres,size))
for i in range(nres):
j = start+i*step
res[i,j] = 1
return res
elif hasattr(self.indices,'__iter__'):
res = np.zeros((len(self.indices),size))
for i,j in enumerate(self.indices):
res[i,j] = 1
return res
else:
res = np.zeros((1,size))
res[0,self.indices] = 1
return res

def jvp(self,arg,dx,x):
return dx[self.indices]

def gen_derivative(self,arg,x):
if len(arg) == 1:
return self.derivative(arg[0],x)
return 0

def __init__(self,indices):
self.indices = indices

def __str__(self):
return 'setitem['+str(self.indices)+']'

def n_args(self):
return 2

def n_in(self,arg):
if arg==0: return -1
if isinstance(self.indices,slice):
if self.indices.stop is not None:
start, stop, step = self.indices.start,self.indices.stop,self.indices.step
if self.indices.step >= 0 and self.indices.stop >= 0:
return (stop-start)//step
return -1
elif hasattr(self.indices,'__iter__'):
return len(self.indices)
else:
return 1

def n_out(self):
return -1

def eval(self,x,v):
xcopy = np.copy(x)
xcopy[self.indices] = v
return xcopy

def derivative(self,arg,x,v):
if arg==0:
diag = np.ones(x.shape)
diag[self.indices] = 0
return np.diag(diag)
else:
res = np.zeros((_size(x),_size(v)))
if isinstance(self.indices,slice):
start, stop, step = self.indices.indices(res.shape[0])
nres = (stop-start)//step
for i in range(nres):
j = start+i*step
res[j,i] = 1
return res
elif hasattr(self.indices,'__iter__'):
for i,j in enumerate(self.indices):
res[j,i] = 1
return res
else:
res[self.indices,0] = 1
return res

def jvp(self,arg,darg,x,v):
if arg==0:
res = np.copy(darg)
res[self.indices] = 0
return res
else:
res = np.zeros(x.shape)
res[self.indices] = darg
return res

def gen_derivative(self,arg,x,v):
if len(arg) == 1:
return self.derivative(arg[0],x,v)
return 0

def __str__(self):
return 'stack'

def n_args(self):
return -1

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,*args):
return np.hstack(args)

def derivative(self,arg,*args):
sizes = [_size(a) for a in args]
totalsize = sum(sizes)
res = np.zeros((totalsize,sizes[arg]))
indices = np.cumsum([0]+sizes)
base = indices[arg]
for i in range(sizes[arg]):
res[base+i,i] = 1
return res

def jvp(self,arg,darg,*args):
sizes = [_size(a) for a in args]
totalsize = sum(sizes)
indices = np.cumsum([0]+sizes)
res = np.zeros(totalsize)
res[indices[arg]:indices[arg]+sizes[arg]] = darg
return res

def gen_derivative(self,arg,*args):
if len(arg) == 1:
return self.derivative(arg[0],*args)
return 0

def __str__(self):
return 'minimum'

def n_args(self):
return -1

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,*args):
if len(args) == 0:
raise ValueError("Can't take a minimum with 0 arguments")
if len(args) == 1:
assert not _scalar(args[0]),"Can't take a minimum over a single scalar"
return args[0].min()
res = np.minimum(args[0],args[1])
for i in range(2,len(args)):
res = np.minimum(res,args[i])
return res

def derivative(self,arg,*args):
if len(args)==1:
res = np.zeros((1,_size(args[0])))
res[0,args[0].argmax()] = 1
return res
sizes = [_size(a) for a in args]
size = max(sizes)
amin = np.zeros(size,dtype=int)
if _scalar(args[0]) and _scalar(args[1]):
if args[1] < args[0]:
amin[:] = 1
else:
amin[args[1] < args[0]] = 1
res = np.minimum(args[0],args[1])
for i in range(2,len(args)):
if _scalar(res) and _scalar(args[i]):
if args[i] < res:
amin[:] = i
else:
amin[args[i] < res] = i
res = np.minimum(res,args[i])
J = np.zeros((size,sizes[arg]))
if sizes[arg] == 1:
for i,index in enumerate(amin):
if index == arg:
J[i,0] = 1
else:
for i,index in enumerate(amin):
if index == arg:
J[i,i] = 1
return J

def jvp(self,arg,darg,*args):
if len(args)==1:
return darg[args[0].argmin()]
sizes = [_size(a) for a in args]
size = max(sizes)
amin = np.zeros(size,dtype=int)
if _scalar(args[0]) and _scalar(args[1]):
if args[1] < args[0]:
amin[:] = 1
else:
amin[args[1] < args[0]] = 1
res = np.minimum(args[0],args[1])
for i in range(2,len(args)):
if _scalar(res) and _scalar(args[i]):
if args[i] < res:
amin[:] = i
else:
amin[args[i] < res] = i
res = np.minimum(res,args[i])
Jdx = np.zeros(size)
if sizes[arg] == 1:
for i,index in enumerate(amin):
if index == arg:
Jdx[i] = darg[0]
else:
for i,index in enumerate(amin):
if index == arg:
Jdx[i] = darg[i]
return Jdx

def gen_derivative(self,arg,*args):
if len(arg) == 1:
return self.derivative(arg[0],*args)
return 0

def __str__(self):
return 'maximum'

def n_args(self):
return -1

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,*args):
if len(args) == 0:
raise ValueError("Can't take a maximum with 0 arguments")
if len(args) == 1:
assert not _scalar(args[0]),"Can't take a minimum over a single scalar"
return args[0].max()
res = np.maximum(args[0],args[1])
for i in range(2,len(args)):
res = np.maximum(res,args[i])
return res

def derivative(self,arg,*args):
if len(args)==1:
res = np.zeros((1,_size(args[0])))
res[0,args[0].argmax()] = 1
return res
sizes = [_size(a) for a in args]
size = max(sizes)
amin = np.zeros(size,dtype=int)
if _scalar(args[0]) and _scalar(args[1]):
if args[1] > args[0]:
amin[:] = 1
else:
amin[args[1] > args[0]] = 1
res = np.maximum(args[0],args[1])
for i in range(2,len(args)):
if _scalar(res) and _scalar(args[i]):
if args[i] > res:
amin[:] = i
else:
amin[args[i] > res] = i
res = np.maximum(res,args[i])
J = np.zeros((size,sizes[arg]))
if sizes[arg] == 1:
for i,index in enumerate(amin):
if index == arg:
J[i,0] = 1
else:
for i,index in enumerate(amin):
if index == arg:
J[i,i] = 1
return J

def jvp(self,arg,darg,*args):
if len(args)==1:
return darg[args[0].argmax()]
sizes = [_size(a) for a in args]
size = max(sizes)
amin = np.zeros(size,dtype=int)
if _scalar(args[0]) and _scalar(args[1]):
if args[1] > args[0]:
amin[:] = 1
else:
amin[args[1] > args[0]] = 1
res = np.maximum(args[0],args[1])
for i in range(2,len(args)):
if _scalar(res) and _scalar(args[i]):
if args[i] > res:
amin[:] = i
else:
amin[args[i] > res] = i
res = np.maximum(res,args[i])
Jdx = np.zeros(size)
if sizes[arg] == 1:
for i,index in enumerate(amin):
if index == arg:
Jdx[i] = darg[0]
else:
for i,index in enumerate(amin):
if index == arg:
Jdx[i] = darg[i]
return Jdx

def gen_derivative(self,arg,*args):
if len(arg) == 1:
return self.derivative(arg[0],*args)
return 0

def __str__(self):
return 'cond'

def argname(self,arg):
return ['pred','posval','negval'][arg]

def n_args(self):
return 3

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,pred,trueval,falseval):
if _scalar(pred):
if pred > 0:
return trueval
return falseval
else:
assert len(pred)==len(trueval)
assert len(pred)==len(falseval)
true_inds = np.where(pred > 0)[0]
res = falseval.copy()
res[true_inds]=trueval[true_inds]
return res

def jvp(self,arg,darg,pred,trueval,falseval):
if arg==0:
return 0
if arg==1:
if _scalar(pred):
if pred > 0:
return darg
return 0
else:
res = np.zeros(trueval.shape)
true_inds = np.where(pred > 0)[0]
res[true_inds] = darg[true_inds]
else:
if _scalar(pred):
if pred > 0:
return 0
return darg
else:
res = darg.copy()
true_inds = np.where(pred > 0)[0]
res[true_inds] = 0
return res

def __str__(self):
return 'cond3'

def argname(self,arg):
return ['pred','posval','zeroval','negval'][arg]

def n_args(self):
return 4

def n_in(self,arg):
return -1

def n_out(self):
return -1

def eval(self,pred,trueval,zeroval,falseval):
if _scalar(pred):
if pred > 0:
return trueval
elif pred == 0:
return zeroval
else:
return falseval
else:
assert len(pred)==len(trueval)
assert len(pred)==len(zeroval)
assert len(pred)==len(falseval)
res = falseval.copy()
res[pred>0]=trueval[pred>0]
res[pred==0]=zeroval[pred==0]
return res

def jvp(self,arg,darg,pred,trueval,zeroval,falseval):
if arg==0:
return 0
if arg==1:
if _scalar(pred):
if pred > 0:
return darg
return 0
else:
res = np.zeros(trueval.shape)
res[pred > 0] = darg[pred > 0]
elif arg==2:
if _scalar(pred):
if pred == 0:
return darg
return 0
else:
res = np.zeros(trueval.shape)
res[pred == 0] = darg[pred == 0]
else:
if _scalar(pred):
if pred >= 0:
return 0
return darg
else:
res = darg.copy()
res[pred >= 0] = 0
return res

"""Autodiff'ed function comparable to the addition operator +. Applied
automatically when the + operator is called on an ADTerminal or an

"""Autodiff'ed function comparable to the subtraction operator -. Applied
automatically when the - operator is called on an ADTerminal or an

"""Autodiff'ed function comparable to the multiplication operator \*. Applied
automatically when the * operator is called on an ADTerminal or an
ADFunctionCall.  Multiplication of vectors is performed element-wise."""

"""Autodiff'ed function comparable to the division operator /. Applied
automatically when the / operator is called on an ADTerminal or an
ADFunctionCall.  Divison of vectors is performed element-wise."""

"""Autodiff'ed function comparable to the negation operator -. Applied
automatically when the - operator is called on an ADTerminal or an

"""Autodiff'ed function comparable to sum.  It acts like np.sum if only one
item is provided.  Otherwise, it acts like the builtin sum."""

"""Autodiff'ed function comparable to np.power. Applied automatically when the

"""Autodiff'ed function comparable to abs. Applied automatically when abs is

"""Autodiff'ed function comparable to np.stack."""

[docs]def getitem(a,index):
"""Creates an Autodiff'ed function comparable to the [] operator. Note that
the index argument cannot be a variable."""

[docs]def setitem(x,index,v):
"""Creates an Autodiff'ed function similar to x[index] = v, return x
but without modifying x.  The index argument cannot be a variable."""

"""Autodiff'ed function comparable to np.minimum. It acts like ndarray.min if
only one item is provided, and otherwise it acts like np.minimum. If more than
2 items are provided, they are applied elementwise and sequentially.
"""

"""Autodiff'ed function comparable to np.maximum. It acts like ndarray.max if
only one item is provided, and otherwise it acts like np.maximum. If more than
2 items are provided, they are applied elementwise and sequentially.
"""

"""Autodiff'ed function that performs a conditional (pred,trueval,falseval).
It returns trueval if pred > 0 and falseval otherwise.   If pred is an array,
then it must have the same size as trueval and falseval, and performs the
conditioning element-wise.
"""

`