Source code for klampt.math.autodiff.se3_ad

"""se3 module AD functions:

Note that each function takes a 12 element vector which is the concatenation
of the R and t terms in an se3 element.  In other words, to convert from an
se3 element T to an entry that can be properly eval'ed, use the code
``np.array(T[0] + T[1])``.  Alternatively, you can use the ``join`` function
``se3_ad.join(T[0],T[1])`` which is differentiable.

To extract out the rotation / translation from an autodiff-ed function, use
the ``se3_ad.rotation`` and ``se3_ad.translation`` functions.

 ====================  =============  ====================================
 Function              Derivative     Notes
 ====================  =============  ====================================
 join                  Y              Creates a 12-element vector from R,t
 identity              N/A            A constant function with no args
 apply                 1
 apply_rotation        1
 mul                   N
 inv                   N
 rotation              Y
 from_rotation         Y
 translation           Y
 from_translation      Y
 distance              N
 interpolate           N
 error                 N
 ====================  =============  ====================================
 
.. note::
    To do gradient descent on se3 elements, you will need to either project to
    the se3 submanifold or use a non-redundant representation.  The rotation 
    vector representation (see so3_ad.rotation_vector and
    so3_ad.from_rotation_vector) is recommended. 

    The code ``se3_ad.join(so3_ad.from_rotation_vector(x[0:3]),x[3:6])`` will
    convert from a 6-element vector x (first 3 elements encoding the rotation
    vector, final 3 elements encoding the translation) to a se3 element.

Module contents
~~~~~~~~~~~~~~~

.. autosummary::
    join
    identity
    apply
    apply_rotation
    mul
    inv
    rotation
    from_rotation
    translation
    from_translation
    distance
    interpolate
    error

"""

import numpy as np 
from .ad import ADFunctionInterface,function
from .. import so3,se3

# Length of the flattened transform vector.  The helpers in this module slice
# the first 9 entries as the rotation and the remaining entries as the
# translation.
SIZE = 12
"""Constant giving the dimension of an se3_ad element"""

def _rotation(T):
    """Return the rotation part of a flattened transform: its first 9 entries."""
    rotation_part = T[0:9]
    return rotation_part
def _translation(T):
    """Return the translation part of a flattened transform: everything after
    the first 9 entries."""
    translation_part = T[9:len(T)]
    return translation_part
def _split(T):
    """Split a flattened transform into a ``(rotation, translation)`` tuple.

    The rotation is the first 9 entries and the translation is the rest.
    """
    rotation_part, translation_part = T[:9], T[9:]
    return (rotation_part, translation_part)
def _join(T):
    """Concatenate the pieces of ``T`` (e.g. a ``(rotation, translation)``
    pair) into one flat numpy array via ``np.hstack``."""
    flattened = np.hstack(T)
    return flattened

def to_klampt(T):
    """Converts an autodiff se3 representation as a length-12 numpy array to
    the native Klampt representation (R,t).

    Args:
        T: the flattened transform.  A numpy array is expected; any other
            sequence is converted with ``np.asarray`` first.

    Returns:
        tuple: ``(R, t)`` where ``R`` is a list holding the first 9 entries
        of ``T`` and ``t`` is a list holding the remaining entries.
    """
    # asarray is a no-op for arrays and lets plain lists/tuples be passed too
    T = np.asarray(T)
    return (T[:9].tolist(), T[9:].tolist())
def from_klampt(T):
    """Converts a native Klampt se3 representation (R,t) to the length-12
    numpy array representation used in autodiff.

    Args:
        T: a pair ``(R, t)`` of sequences to concatenate.

    Returns:
        numpy.ndarray: the entries of ``R`` followed by the entries of ``t``.
    """
    # Same concatenation the module's _join helper performs.
    return np.hstack(T)
join = function(lambda R,t:np.hstack((R,t)),'se3.join',[9,3],12,
        # Jacobians of the 12-vector w.r.t. R (12x9) and w.r.t. t (12x3)
        derivative=[lambda R,t:np.vstack((np.eye(9),np.zeros((3,9)))),
                    lambda R,t:np.vstack((np.zeros((9,3)),np.eye(3)))],
        jvp=[lambda dR,R,t:np.hstack((dR,np.zeros(3))),
             lambda dt,R,t:np.hstack((np.zeros(9),dt))],
        order=1)
"""Autodiff function to join SO(3) rotation matrix and 3D translation vector
into an SE(3) 12-element vector. All derivatives are implemented."""

rotation = function(_rotation,'se3.rotation',(12,),9,
        # Explicit Jacobian kept for reference; only the JVP is registered.
        #derivative=[lambda T:np.hstack((np.eye(9),np.zeros((9,3))))],
        jvp=[lambda dT,T:_rotation(dT)],
        order=1)
"""Autodiff function to extract the SO(3) rotation component of an SE(3)
12-element vector. All derivatives are implemented."""

translation = function(_translation,'se3.translation',(12,),3,
        # Explicit Jacobian kept for reference; only the JVP is registered.
        #derivative=[lambda T:np.hstack((np.zeros((3,9)),np.eye(3)))],
        jvp=[lambda dT,T:_translation(dT)],
        order=1)
"""Autodiff function to extract the 3-D translation component of an SE(3)
12-element vector. All derivatives are implemented."""

def identity():
    """A 0 argument function that returns the SE(3) identity as a 12-element
    vector (the joined result of se3.identity())."""
    return _join(se3.identity())

apply = function(lambda T,x:se3.apply(_split(T),x),'se3.apply',(12,3),3,['T','x'],
        # d(R*x+t) in direction dT is dR*x + dt; in direction dx it is R*dx
        jvp=[lambda dT,T,x:np.array(so3.apply(_rotation(dT),x))+_translation(dT),
             lambda dx,T,x:np.array(so3.apply(_rotation(T),dx))],
        order=2)
"""Autodiff wrapper of se3.apply. First derivatives are implemented."""

apply_rotation = function(lambda T,x:so3.apply(_rotation(T),x),'se3.apply_rotation',(12,3),3,['T','x'],
        jvp=[lambda dT,T,x:np.array(so3.apply(_rotation(dT),x)),
             lambda dx,T,x:np.array(so3.apply(_rotation(T),dx))],
        order=2)
"""Autodiff wrapper of se3.apply_rotation. First derivatives are
implemented."""

# Argument names were previously ['T1','T1']; the second operand is T2.
mul = function(lambda T1,T2:_join(se3.mul(_split(T1),_split(T2))),'se3.mul',(12,12),12,['T1','T2'])
"""Autodiff wrapper of se3.mul."""

inv = function(lambda T:_join(se3.inv(_split(T))),'se3.inv',(12,),12)
"""Autodiff wrapper of se3.inv."""

from_rotation = function(lambda R:_join((R,np.zeros(3))),'se3.from_rotation',(9,),12,
        # The map R -> (R,0) is linear, so its JVP applies the same map to the
        # direction dR (not to R itself).
        jvp=[lambda dR,R:_join((dR,np.zeros(3)))],
        order=1)
"""Autodiff wrapper of se3.from_rotation. All derivatives are implemented."""

from_translation = function(lambda t:_join((np.eye(3).flatten(),t)),'se3.from_translation',(3,),12,
        # The identity rotation block is constant in t, so it contributes zeros
        # to the JVP; only the translation block carries dt.
        jvp=[lambda dt,t:_join((np.zeros(9),dt))],
        order=1)
"""Autodiff wrapper of se3.from_translation. All derivatives are
implemented."""

error = function(lambda T1,T2:se3.error(_split(T1),_split(T2)),'se3.error',(12,12),6,['T1','T2'])
"""Autodiff wrapper of se3.error."""

distance = function(lambda T1,T2:se3.distance(_split(T1),_split(T2)),'se3.distance',(12,12),1,['T1','T2'])
"""Autodiff wrapper of se3.distance."""

interpolate = function(lambda T1,T2,u:_join(se3.interpolate(_split(T1),_split(T2),u)),'se3.interpolate',(12,12,1),12,['Ta','Tb','u'])
"""Autodiff wrapper of se3.interpolate."""