# Source code for klampt.math.autodiff.so3_ad

"""so3 module AD functions:

 ====================  ============= 
 Function              Derivative    
 ====================  ============= 
 identity              N/A
 apply                 Y
 mul                   Y
 inv                   Y
 trace                 Y
 rpy                   N
 from_rpy              1
 rotation_vector       1
 from_rotation_vector  1
 axis                  N
 angle                 1
 from_axis_angle       1
 quaternion            N
 from_quaternion       N
 distance              N
 error                 1
 cross_product         N
 diag                  Y
 deskew                Y
 canonical             N
 interpolate           N
 det                   N
 ====================  ============= 

All so3_ad elements are plain 9-element numpy arrays, using the same flattened
3x3 matrix layout as the so3 module.

.. note::
    To do gradient descent on so3 elements, you will need to either project
    back onto the SO(3) manifold or use a non-redundant representation.  The
    rotation vector representation (see rotation_vector/from_rotation_vector)
    is recommended.

Module contents
~~~~~~~~~~~~~~~

.. autosummary::
    identity
    apply
    mul
    inv
    trace
    rpy
    from_rpy
    rotation_vector
    from_rotation_vector
    axis
    angle
    from_axis_angle
    quaternion
    from_quaternion
    distance
    error
    cross_product
    diag
    deskew
    canonical
    interpolate
    det

"""

import numpy as np 
from .ad import ADFunctionInterface,function
from . import math_ad
from .. import vectorops,so3
import math

SIZE = 3 * 3
"""Number of entries in an so3_ad element: a 3x3 rotation matrix flattened
into 9 numbers."""

def identity():
    """Return the SO(3) identity as a new 9-element numpy array.

    Wraps so3.identity() and converts its result with ``np.array``, so each
    call returns a fresh array that the caller may modify freely.
    """
    return np.array(so3.identity())

# The JVP applies so3.inv to the tangent dR itself, i.e. inv is treated as a
# linear map of R (hence order=1).
inv = function(so3.inv,'so3.inv',(9,),9,
        jvp=[lambda dR,R:so3.inv(dR)],order=1)
"""Autodiff'ed version of so3.inv. All derivatives are implemented."""

# so3.mul is bilinear in (R1, R2): each JVP substitutes the tangent into its
# own slot and keeps the other argument fixed.
mul = function(so3.mul,'so3.mul',(9,9),9,['R1','R2'],
        jvp=[lambda dR1,R1,R2:so3.mul(dR1,R2),lambda dR2,R1,R2:so3.mul(R1,dR2)],order=2)
"""Autodiff'ed version of so3.mul. First derivatives are implemented."""

def _apply_jvp_R(dR,R,x):
    # Rotating x by the tangent matrix: so3.apply is linear in R.
    return so3.apply(dR,x)

def _apply_jvp_x(dx,R,x):
    # Rotating the tangent vector: so3.apply is linear in x.
    return so3.apply(R,dx)

apply = function(so3.apply,'so3.apply',(9,3),3,['R','x'],
        jvp=[_apply_jvp_R,_apply_jvp_x],
        order=2)
"""Autodiff'ed version of so3.apply (rotate a 3-vector by R). First
derivatives are implemented."""

# The trace is linear in R, so its JVP is just the trace of the tangent.
trace = function(so3.trace,'so3.trace',(9,),1,
        jvp=[lambda dR,R:so3.trace(dR)],order=1)
"""Autodiff'ed version of so3.trace. All derivatives are implemented."""

def _from_rpy_jvp(drpy,rpy):
    """Jacobian-vector product of so3.from_rpy at ``rpy`` along ``drpy``.

    Treats the rotation as R = Rz(yaw) * Ry(pitch) * Rx(roll) and applies
    the product rule.  It uses d/dt Rk(t) = [e_k] Rk(t) for a rotation about
    coordinate axis e_k:

        dR = [wz] Rz Ry Rx + Rz [wy] Ry Rx + Rz Ry [wx] Rx

    with wx = droll*e_x, wy = dpitch*e_y, wz = dyaw*e_z.  Returns a
    9-element numpy array.
    """
    roll,pitch,yaw = rpy
    droll,dpitch,dyaw = drpy
    # Use the plain so3 function here: the module-level AD from_axis_angle
    # takes (axis, angle) as two separate arguments.
    Rx = so3.from_axis_angle(((1,0,0),roll))
    Ry = so3.from_axis_angle(((0,1,0),pitch))
    Rz = so3.from_axis_angle(((0,0,1),yaw))
    wx = so3.cross_product([droll,0,0])
    wy = so3.cross_product([0,dpitch,0])
    wz = so3.cross_product([0,0,dyaw])
    Ryx = so3.mul(Ry,Rx)
    yaw_term = so3.mul(so3.mul(wz,Rz),Ryx)
    pitch_term = so3.mul(Rz,so3.mul(wy,Ryx))
    roll_term = so3.mul(so3.mul(Rz,Ry),so3.mul(wx,Rx))
    return np.asarray(yaw_term) + np.asarray(pitch_term) + np.asarray(roll_term)
from_rpy = function(so3.from_rpy,'so3.from_rpy',(3,),9,
    jvp=[_from_rpy_jvp])
"""Autodiff'ed version of so3.from_rpy. First derivatives are implemented."""

rpy = function(so3.rpy,
               'so3.rpy',
               (9,),
               3)
"""Autodiff'ed version of so3.rpy (matrix -> roll/pitch/yaw). No derivatives
are implemented."""

def _from_rotation_vector_jvp(dw,w):
    """Jacobian-vector product of so3.from_rotation_vector at ``w`` along ``dw``.

    Writes w = axis*length and applies the chain rule through
    from_axis_angle(axis, length):

        dR = dR/daxis * daxis + dR/dangle * dlength

    with dlength = <axis, dw> and daxis = (dw - axis*dlength)/length.
    For (near-)zero w the map behaves like I + [w], so the derivative there
    is the cross-product matrix [dw].  Returns a 9-element numpy array.
    """
    w = np.asarray(w,dtype=float)
    dw = np.asarray(dw,dtype=float)
    length = np.linalg.norm(w)
    if length < 1e-7:
        return np.array(so3.cross_product(dw))
    axis = w/length
    # Norm and unit-direction derivatives are computed only after the
    # small-length check, so they never divide by zero.
    dlength = np.dot(axis,dw)
    daxis = (dw - axis*dlength)/length
    return (_from_axis_angle_jvp_axis(daxis,axis,length)
            + _from_axis_angle_jvp_angle(dlength,axis,length))
from_rotation_vector = function(so3.from_rotation_vector,'so3.from_rotation_vector',(3,),9,
    jvp = [_from_rotation_vector_jvp])
"""Autodiff'ed version of so3.from_rotation_vector. First derivatives are
implemented."""

def _rotation_vector_jvp(dR,R):
    """Jacobian-vector product of so3.rotation_vector at ``R`` along ``dR``.

    Differentiates the formula
    ``w = scale*(R[5]-R[7], R[6]-R[2], R[1]-R[3])`` with
    ``scale = 0.5*theta/sin(theta)`` (replaced by its limit 0.5 when
    ``|theta| <= 1e-5``), where theta = so3.angle(R).  Returns a
    3-element numpy array.

    NOTE(review): angles near pi are not special-cased (sin(theta) -> 0 and
    _angle_jvp becomes singular); confirm how so3.rotation_vector handles
    that range.
    """
    theta = so3.angle(R)
    # In the tiny-angle branch the scale is the constant 0.5, whose
    # derivative is 0 (the exact derivative of 0.5*theta/sin(theta) is
    # about theta/6, negligible there).
    scale = 0.5
    dscale = 0.0
    if abs(theta) > 1e-5:
        dtheta = _angle_jvp(dR,R)
        s = math.sin(theta)
        ds = math.cos(theta)*dtheta
        scale = 0.5*theta/s
        # quotient rule applied to 0.5*theta/s
        dscale = 0.5*(dtheta/s - theta/s**2*ds)
    x = (dR[3+2]-dR[6+1]) * scale + (R[3+2]-R[6+1]) * dscale
    y = (dR[6+0]-dR[0+2]) * scale + (R[6+0]-R[0+2]) * dscale
    z = (dR[0+1]-dR[3+0]) * scale + (R[0+1]-R[3+0]) * dscale
    return np.array([x,y,z])
rotation_vector = function(so3.rotation_vector,'so3.rotation_vector',(9,),3,
    jvp=[_rotation_vector_jvp])
"""Autodiff'ed version of so3.rotation_vector. First derivatives are
implemented."""

# NOTE(review): order=2 is declared, but no jvp/derivative is supplied.
from_quaternion = function(so3.from_quaternion,'so3.from_quaternion',(4,),9,order=2)
"""Autodiff'ed version of so3.from_quaternion. No derivatives are
implemented."""

quaternion = function(so3.quaternion,'so3.quaternion',(9,),4)
"""Autodiff'ed version of so3.quaternion. No derivatives are implemented."""

def _from_axis_angle_derivative_axis(axis,angle):
    """Jacobian of from_axis_angle with respect to ``axis``.

    Returns a 9x3 numpy array whose k-th column is the JVP along the k-th
    coordinate direction, so it always agrees with _from_axis_angle_jvp_axis.
    """
    return np.column_stack([_from_axis_angle_jvp_axis(e,axis,angle) for e in np.eye(3)])
def _from_axis_angle_derivative_angle(axis,angle):
    """Jacobian of from_axis_angle with respect to ``angle``, as a 9x1 array.

    For a unit axis, differentiating the Rodrigues formula gives
    dR/dangle = R*[axis].
    """
    #m = cos(angle)*I + (1-cos(angle))axis*axis^T + sin(angle)[axis]
    R = so3.from_axis_angle((axis,angle))
    return np.array(so3.mul(R,so3.cross_product(axis)))[:,np.newaxis]
def _from_axis_angle_jvp_axis(daxis,axis,angle):
    """JVP of from_axis_angle with respect to ``axis`` along ``daxis``.

    Differentiates the Rodrigues formula term by term, treating ``axis`` as
    an unconstrained 3-vector.  Returns a 9-element numpy array.
    """
    #m = cos(angle)*I + (1-cos(angle))axis*axis^T + sin(angle)[axis]
    #dm/daxis*delta = (1-cos(angle))d/daxis(axis*axis^T) + sin(angle)d/daxis([axis]))*delta
    #so result is (1-cos(angle)) (delta x^T + x delta^T) + sin(angle)[delta]
    c = math.cos(angle)
    s = math.sin(angle)
    R = vectorops.mul(so3.cross_product(daxis),s)
    # The outer-product term is symmetric, so the flattening order of the
    # 3x3 matrix does not matter here.
    for i in range(3):
        for j in range(3):
            R[i*3+j] += (daxis[i]*axis[j] + axis[i]*daxis[j])*(1-c)
    return np.array(R)
def _from_axis_angle_jvp_angle(dangle,axis,angle):
    """JVP of from_axis_angle with respect to ``angle``: dangle * R*[axis]."""
    R = so3.from_axis_angle((axis,angle))
    return dangle*np.array(so3.mul(R,so3.cross_product(axis)))
from_axis_angle = function(lambda axis,angle:so3.from_axis_angle((axis,angle)),'so3.from_axis_angle',(3,1),9,['axis','angle'],
                    derivative=[_from_axis_angle_derivative_axis,_from_axis_angle_derivative_angle],
                    jvp=[_from_axis_angle_jvp_axis,_from_axis_angle_jvp_angle])
"""Autodiff'ed version of so3.from_axis_angle. First derivatives are
implemented."""

def _axis_jvp(dR,R):
    """JVP of the unit rotation axis of R along dR.

    Chains the rotation-vector JVP with the derivative of normalization
    (math_ad._unit_jvp).
    """
    rotvec_tangent = _rotation_vector_jvp(dR,R)
    rotvec = np.array(so3.rotation_vector(R))
    return math_ad._unit_jvp(rotvec_tangent,rotvec)
axis = function(lambda R: vectorops.unit(so3.rotation_vector(R)),
                'axis',(9,),3,
                jvp=[_axis_jvp])
"""Autodiff'ed version of so3.axis_angle(R)[0], computed as the normalized
rotation vector. First derivatives are implemented."""

def _angle_jvp(dR,R):
    """JVP of so3.angle at R along dR.

    The angle is acos(c) with c = (trace(R)-1)/2 clamped to [-1, 1], so the
    JVP is -1/sqrt(1-c^2) * trace(dR)/2.  At c == 1 (identity) acos is not
    differentiable; the norm of the entries dR[1], dR[2], dR[5] is returned
    instead.

    NOTE(review): at c == -1 (angle pi) the denominator is zero and this
    raises ZeroDivisionError.
    """
    c = max(min((so3.trace(R) - 1)*0.5,1.0),-1.0)
    if c != 1:
        slope = -1.0/math.sqrt(1-c**2)
        return so3.trace(dR)*0.5*slope
    return vectorops.norm([dR[1],dR[2],dR[5]])
angle = function(so3.angle,'so3.angle',(9,),1,
                 jvp=[_angle_jvp])
"""Autodiff'ed version of so3.angle (rotation angle of R). First derivatives
are implemented."""

def _error_jvp_Ra(dRa,Ra,Rb):
    """JVP of so3.error with respect to its first argument.

    error = rotation_vector(Ra * Rb^-1), and d(Ra*Rb^-1)/dRa * dRa is
    dRa * Rb^-1, so the rotation-vector JVP is evaluated at the relative
    rotation along that tangent.  Returns a 3-element numpy array.
    """
    Rbinv = so3.inv(Rb)
    Rrel = so3.mul(Ra,Rbinv)
    dRrel = so3.mul(dRa,Rbinv)
    return np.asarray(_rotation_vector_jvp(dRrel,Rrel))
def _error_jvp_Rb(dRb,Ra,Rb):
    """JVP of so3.error with respect to its second argument.

    Uses error(Ra,Rb) = -error(Rb,Ra) and reuses the first-argument JVP with
    the roles swapped.  The result is converted to an array before negation,
    because _rotation_vector_jvp may return a plain list, which cannot be
    negated.
    """
    return -np.asarray(_error_jvp_Ra(dRb,Rb,Ra))
error = function(so3.error,'so3.error',(9,9),3,
    jvp=[_error_jvp_Ra,_error_jvp_Rb])
"""Autodiff'ed version of so3.error. First derivatives are
implemented."""

distance = function(so3.distance,
                    'so3.distance',
                    (9,9),1,
                    ['Ra','Rb'])
"""Autodiff'ed version of so3.distance. No derivatives are implemented."""

canonical = function(so3.canonical,
                     'so3.canonical',
                     (3,),9)
"""Autodiff'ed version of so3.canonical. No derivatives are implemented."""

def _cross_product_jvp(dx,x):
    # cross_product is linear in x: its JVP is the same map applied to dx.
    return so3.cross_product(dx)

def _diag_jvp(dR,R):
    # diag is linear in R: its JVP is the diagonal of the tangent.
    return so3.diag(dR)

def _deskew_jvp(dR,R):
    # deskew is linear in R: its JVP is the deskew of the tangent.
    return so3.deskew(dR)

cross_product = function(so3.cross_product,'so3.cross_product',(3,),9,['x'],
        jvp=[_cross_product_jvp],order=1)
"""Autodiff'ed version of so3.cross_product (vector -> skew-symmetric
matrix). All derivatives are implemented."""

diag = function(so3.diag,'so3.diag',(9,),3,['R'],
        jvp=[_diag_jvp],order=1)
"""Autodiff'ed version of so3.diag. All derivatives are implemented."""

deskew = function(so3.deskew,'so3.deskew',(9,),3,['R'],
        jvp=[_deskew_jvp],order=1)
"""Autodiff'ed version of so3.deskew. All derivatives are implemented."""

def _interpolate_deriv_u(du,Ra,Rb,u):
    """JVP of so3.interpolate with respect to the parameter ``u``.

    Follows the same (tangent, *args) signature as every other JVP in this
    module.  With x = interpolate(Ra,Rb,u), the world-frame angular velocity
    along u is error(Rb,x) - error(Ra,x), so dx/du = [eb - ea] * x.  This
    assumes so3.interpolate follows the geodesic from Ra to Rb.  The result
    is scaled by the tangent ``du`` and returned as a 9-element numpy array.
    """
    x = so3.interpolate(Ra,Rb,u)
    ea = so3.cross_product(so3.error(Ra,x))
    eb = so3.cross_product(so3.error(Rb,x))
    dx_du = so3.mul(vectorops.sub(eb,ea),x)
    return np.asarray(dx_du)*du

interpolate = function(so3.interpolate,'so3.interpolate',(9,9,1),9,['Ra','Rb','u'],
    jvp=[None,None,_interpolate_deriv_u])
"""Autodiff'ed version of so3.interpolate. Only the derivative with respect
to u is implemented."""

det = function(so3.det,
               'so3.det',
               (9,),
               1)
"""Autodiff'ed version of so3.det (matrix determinant). No derivatives are
implemented."""