```"""se3 module AD functions:

Note that each function takes a 12 element vector which is the concatenation
of the R and t terms in an se3 element.  In other words, to convert from an
se3 element T to an entry that can be properly eval'ed, use the code
``np.array(T[0] + T[1])``.  Alternatively, you can use the ``join`` function,
e.g., ``join(T[0],T[1])``.

To extract out the rotation / translation from an autodiff-ed function, use
the ``rotation`` and ``translation`` functions.

====================  =============  ====================================
Function              Derivative     Notes
====================  =============  ====================================
join                  Y              Creates a 12-element vector from R,t
identity              N/A            A constant function with no args
apply                 1
apply_rotation        1
mul                   N
inv                   N
rotation              Y
from_rotation         Y
translation           Y
from_translation      Y
distance              N
interpolate           N
error                 N
====================  =============  ====================================

.. note::
    To do gradient descent on se3 elements, you will need to either project to
    the se3 submanifold or use a non-redundant representation.  The rotation
    vector representation is one such choice: optimize over a 6-element
    vector x (first 3 elements encoding the rotation vector, final 3 elements
    encoding the translation) and convert from x to a se3 element before
    evaluating the functions in this module.

Module contents
~~~~~~~~~~~~~~~

.. autosummary::

    join
    identity
    apply
    apply_rotation
    mul
    inv
    rotation
    from_rotation
    translation
    from_translation
    distance
    interpolate
    error

"""

import numpy as np
from .. import so3,se3

SIZE = 9 + 3  # 9 rotation entries followed by 3 translation entries
"""Length of the flat vector used to represent an se3_ad element."""

def _rotation(T):
    """Return the rotation part of a 12-element se3_ad vector.

    The rotation occupies the first 9 entries of ``T``; the result is the
    slice ``T[:9]`` (no copy is forced, so for numpy arrays it is a view).
    """
    return T[:9]
def _translation(T):
    """Return the translation part of a 12-element se3_ad vector.

    The translation occupies the entries after the first 9 of ``T``; the
    result is the slice ``T[9:]``.
    """
    return T[9:]
def _split(T):
    """Split a 12-element se3_ad vector into a ``(rotation, translation)`` pair.

    Returns the tuple ``(T[:9], T[9:])``, i.e. the first 9 entries and the
    remaining entries of ``T``.
    """
    return (T[:9], T[9:])
def _join(T):
    """Concatenate the parts of ``T`` into one flat numpy array.

    ``T`` is a sequence of 1-D sequences (in this module, a rotation /
    translation pair); the parts are stacked end to end with ``np.hstack``.
    """
    return np.hstack(T)

def to_klampt(T):
    """Converts an autodiff se3 representation as a length-12 numpy array to
    the native Klampt representation (R,t).

    Args:
        T: length-12 array-like; the first 9 entries are the rotation and
            the remaining 3 are the translation.

    Returns:
        tuple: ``(R, t)`` where ``R`` is a list of the first 9 entries and
        ``t`` is a list of the remaining entries.
    """
    # Accept any array-like (not only ndarray) so that .tolist() is available.
    T = np.asarray(T)
    return (T[:9].tolist(), T[9:].tolist())

def from_klampt(T):
    """Converts a native Klampt se3 representation (R,t) to the length-12 numpy
    array representation used in autodiff.

    Args:
        T: a pair ``(R, t)`` of sequences; they are concatenated in that
            order.

    Returns:
        numpy.ndarray: the flat concatenation of ``R`` followed by ``t``.
    """
    return _join(T)

join = function(
    lambda R,t: np.hstack((R, t)),
    'se3.join',
    [9, 3],
    12,
    derivative=[
        # With respect to R: a 12x9 matrix, identity on top of a zero block.
        lambda R,t: np.vstack((np.eye(9), np.zeros((3, 9)))),
        # With respect to t: a 12x3 matrix, zero block on top of the identity.
        lambda R,t: np.vstack((np.zeros((9, 3)), np.eye(3))),
    ],
    jvp=[
        # A perturbation dR only fills the first 9 output entries.
        lambda dR,R,t: np.hstack((dR, np.zeros(3))),
        # A perturbation dt only fills the last 3 output entries.
        lambda dt,R,t: np.hstack((np.zeros(9), dt)),
    ],
    order=1)
"""Autodiff function that concatenates a 9-element SO(3) rotation and a
3-element translation into a 12-element SE(3) vector.  All derivatives are
implemented."""

rotation = function(
    _rotation,
    'se3.rotation',
    (12,),
    9,
    # No explicit derivative matrix is registered; the equivalent would be
    # np.hstack((np.eye(9),np.zeros((9,3)))).  Only the jvp below is given.
    jvp=[lambda dT,T: _rotation(dT)],
    order=1)
"""Autodiff function that extracts the 9-element SO(3) rotation component
from a 12-element SE(3) vector.  All derivatives are implemented."""

translation = function(
    _translation,
    'se3.translation',
    (12,),
    3,
    # No explicit derivative matrix is registered; the equivalent would be
    # np.hstack((np.zeros((3,9)),np.eye(3))).  Only the jvp below is given.
    jvp=[lambda dT,T: _translation(dT)],
    order=1)
"""Autodiff function that extracts the 3-element translation component
from a 12-element SE(3) vector.  All derivatives are implemented."""

identity = lambda: _join(se3.identity())
"""A 0 argument function that returns the SE(3) identity, i.e. the result of
``se3.identity()`` flattened into the numpy vector form used by this module."""

apply = function(
    lambda T,x: se3.apply(_split(T), x),
    'se3.apply',
    (12, 3),
    3,
    ['T', 'x'],
    jvp=[
        # Perturbing T: rotate x by the rotation part of dT, then add the
        # translation part of dT.
        lambda dT,T,x: np.array(so3.apply(_rotation(dT), x)) + _translation(dT),
        # Perturbing x: rotate dx by the rotation part of T.
        lambda dx,T,x: np.array(so3.apply(_rotation(T), dx)),
    ],
    order=2)
"""Autodiff wrapper of se3.apply, taking a 12-element transform T and a
3-element point x.  First derivatives are implemented."""

apply_rotation = function(
    lambda T,x: so3.apply(_rotation(T), x),
    'se3.apply_rotation',
    (12, 3),
    3,
    ['T', 'x'],
    jvp=[
        # Perturbing T: only the rotation part of dT acts on x.
        lambda dT,T,x: np.array(so3.apply(_rotation(dT), x)),
        # Perturbing x: rotate dx by the rotation part of T.
        lambda dx,T,x: np.array(so3.apply(_rotation(T), dx)),
    ],
    order=2)
"""Autodiff wrapper of se3.apply_rotation: applies only the rotation part of
the 12-element transform T to the 3-element vector x.  First derivatives are
implemented."""

mul = function(
    lambda T1,T2: _join(se3.mul(_split(T1), _split(T2))),
    'se3.mul',
    (12, 12),
    12,
    # The two arguments must have distinct names; the second was previously
    # mislabeled 'T1'.
    ['T1', 'T2'])
"""Autodiff wrapper of se3.mul, taking two 12-element transforms and
returning a 12-element transform.  No derivatives are registered."""

inv = function(
    lambda T: _join(se3.inv(_split(T))),
    'se3.inv',
    (12,),
    12)
"""Autodiff wrapper of se3.inv, mapping a 12-element transform to a
12-element transform.  No derivatives are registered."""

from_rotation = function(
    lambda R: _join((R, np.zeros(3))),
    'se3.from_rotation',
    (9,),
    12,
    # R -> (R, 0) is linear in R, so the directional derivative along dR is
    # (dR, 0).  It must be built from the perturbation dR, not from R itself.
    jvp=[lambda dR,R: _join((dR, np.zeros(3)))],
    order=1)
"""Autodiff wrapper of se3.from_rotation: builds a 12-element transform from
a 9-element rotation R and a zero translation.  All derivatives are
implemented."""

from_translation = function(
    lambda t: _join((np.eye(3).flatten(), t)),
    'se3.from_translation',
    (3,),
    12,
    # The rotation block is the constant identity, so its derivative is zero;
    # only the translation block carries the perturbation dt.
    jvp=[lambda dt,t: _join((np.zeros(9), dt))],
    order=1)
"""Autodiff wrapper of se3.from_translation: builds a 12-element transform
from the identity rotation and a 3-element translation t.  All derivatives
are implemented."""

error = function(
    lambda T1,T2: se3.error(_split(T1), _split(T2)),
    'se3.error',
    (12, 12),
    6,
    ['T1', 'T2'])
"""Autodiff wrapper of se3.error, taking two 12-element transforms and
producing a 6-element output.  No derivatives are registered."""

distance = function(
    lambda T1,T2: se3.distance(_split(T1), _split(T2)),
    'se3.distance',
    (12, 12),
    1,
    ['T1', 'T2'])
"""Autodiff wrapper of se3.distance, taking two 12-element transforms and
producing a 1-element output.  No derivatives are registered."""

interpolate = function(
    lambda T1,T2,u: _join(se3.interpolate(_split(T1), _split(T2), u)),
    'se3.interpolate',
    (12, 12, 1),
    12,
    ['Ta', 'Tb', 'u'])
"""Autodiff wrapper of se3.interpolate, taking two 12-element transforms and
a 1-element parameter u and returning a 12-element transform.  No derivatives
are registered."""

```