Source code for klampt.math.autodiff.pytorch

import as ad
import torch,numpy as np

class TorchModuleFunction(ad.ADFunctionInterface):
    """Converts a PyTorch function to a Klamp't autodiff function class.

    The wrapped ``module`` is any callable that accepts torch tensors and
    returns a torch tensor.  The most recent evaluation (input tensors and
    flattened output) is cached so that :meth:`derivative` can reuse the
    autograd graph when it is called with the same arguments.

    Note: constructing an instance switches the global torch default dtype to
    ``float64``.
    """

    def __init__(self, module):
        self.module = module
        # Tensors used for the most recent successful eval(); _eval_result is
        # only set once eval() has succeeded at least once.
        self._eval_params = []
        torch.set_default_dtype(torch.float64)

    def __str__(self):
        return str(self.module)

    def n_in(self, arg):
        """Returns -1 for every argument (presumably "size not fixed" in the
        ADFunctionInterface convention -- TODO confirm)."""
        return -1

    def n_out(self):
        """Returns -1 (presumably "size not fixed" -- TODO confirm)."""
        return -1

    def eval(self, *args):
        """Evaluates the wrapped module and returns a flat numpy array.

        Args:
            *args: numpy arrays, or other values which are wrapped as
                ``np.array([value])``.

        Returns:
            The module output, flattened to 1-D, as a numpy array.

        Raises:
            Exception: whatever the wrapped module raises.  The error is
                printed and then re-raised; the cached evaluation from a
                previous successful call is left untouched.
        """
        params = []
        for a in args:
            if not isinstance(a, np.ndarray):
                a = np.array([a])
            p = torch.Tensor(a)
            p.requires_grad_(True)
            params.append(p)
        try:
            result = torch.flatten(self.module(*params))
        except Exception as e:
            print('Torch error: %s' % str(e))
            # Re-raise: falling through would return a stale (or missing)
            # result from an earlier call as if it belonged to these args.
            raise
        # Commit the cache only after success so that the stored parameters
        # and the stored result always belong to the same evaluation.
        self._eval_params = params
        self._eval_result = result
        return result.detach().numpy()

    def derivative(self, arg, *args):
        """Returns the Jacobian of the flattened output w.r.t. argument ``arg``.

        Args:
            arg: index of the argument to differentiate against.
            *args: the arguments, in the same form accepted by :meth:`eval`.

        Returns:
            A numpy array with one row per output entry and one column per
            element of the selected argument.  Rows are zero where the output
            does not depend on the argument.
        """
        # Lazily re-run the forward pass only if the arguments changed.
        if not self._same_param(*args):
            self.eval(*args)
        param = self._eval_params[arg]
        n_rows = self._eval_result.shape[0]
        n_cols = param.numel()
        if n_rows == 0 or not self._eval_result.requires_grad:
            # Empty output, or output not connected to any input tensor.
            return np.zeros((n_rows, n_cols))
        rows = []
        # The Jacobian is assembled one output row at a time, which is slow.
        # torch.autograd.grad is used instead of .backward() so gradients are
        # not accumulated into the .grad fields of the wrapped module's own
        # parameters as a side effect.
        for i in range(n_rows):
            (g,) = torch.autograd.grad(self._eval_result[i], param,
                                       retain_graph=True, allow_unused=True)
            if g is None:
                rows.append(np.zeros(n_cols))
            else:
                rows.append(g.detach().numpy().flatten())
        return np.vstack(rows)

    def jvp(self, arg, darg, *args):
        """Jacobian-vector products are not implemented."""
        raise NotImplementedError('')

    def _same_param(self, *args):
        """Returns True if ``args`` match the cached evaluation's inputs."""
        if not hasattr(self, "_eval_params") or not hasattr(self, "_eval_result"):
            return False
        if len(self._eval_params) != len(args):
            return False
        for p, a in zip(self._eval_params, args):
            cached = p.detach().numpy()
            if not isinstance(a, np.ndarray):
                a = np.array([a])
            # array_equal is False on a shape mismatch or any differing entry.
            if not np.array_equal(cached, a):
                return False
        return True
class ADModule(torch.autograd.Function):
    """Converts a Klamp't autodiff function call or function instance to a
    PyTorch Function.  The class must be created with the terminal symbols
    corresponding to the PyTorch arguments to which this is called.

    Invoke as ``ADModule.apply(func, terminals, *tensors)``, where ``func`` is
    an ``ad.ADFunctionCall`` or ``ad.ADFunctionInterface`` and ``terminals``
    has one entry (an ``ad.ADTerminal`` or a name) per tensor.
    """

    @staticmethod
    def _terminal_name(t):
        """Returns the name of an ADTerminal, or ``t`` itself otherwise."""
        if isinstance(t, ad.ADTerminal):
            return t.name
        return t

    @staticmethod
    def forward(ctx, func, terminals, *args):
        """Evaluates ``func`` on the numpy values of the tensors ``args``.

        Raises:
            ValueError: if the number of tensors differs from the number of
                terminals, or ``func`` is not a supported type.
        """
        torch.set_default_dtype(torch.float64)
        if len(args) != len(terminals):
            raise ValueError("Function %s expected to have %d arguments, instead got %d"%(str(func),len(terminals),len(args)))
        if isinstance(func, ad.ADFunctionCall):
            # Function calls are evaluated by keyword, keyed on terminal name.
            # Names are resolved the same way as in backward() so terminals
            # may be given either as ADTerminal objects or as plain names.
            context = {}
            for t, a in zip(terminals, args):
                context[ADModule._terminal_name(t)] = a.detach().numpy()
            ret = func.eval(**context)
        elif isinstance(func, ad.ADFunctionInterface):
            # Bare functions are evaluated positionally.
            context = [a.detach().numpy() for a in args]
            ret = func.eval(*context)
        else:
            raise ValueError("f must be a ADFunctionCall or ADFunctionInterface")
        # Kept for backward(), which re-evaluates derivatives at these inputs.
        ctx.saved_state = (func, terminals, context)
        return torch.Tensor(ret)

    @staticmethod
    def backward(ctx, grad):
        """Returns the gradients for each input of :meth:`forward`.

        The first two entries are ``None`` because ``func`` and ``terminals``
        are not differentiable; each remaining entry is the transposed
        derivative of ``func`` w.r.t. that argument applied to ``grad``.
        """
        func, terminals, context = ctx.saved_state
        ret = [None, None]
        if isinstance(func, ad.ADFunctionCall):
            for t in terminals:
                name = ADModule._terminal_name(t)
                deriv = torch.Tensor(func.derivative(name, **context))
                ret.append(deriv.T @ grad)
        elif isinstance(func, ad.ADFunctionInterface):
            for k in range(len(terminals)):
                deriv = torch.Tensor(func.derivative(k, *context))
                ret.append(deriv.T @ grad)
        else:
            raise ValueError("f must be a ADFunctionCall or ADFunctionInterface")
        return tuple(ret)

    @staticmethod
    def check_derivatives_torch(func, terminals, h=1e-6, rtol=1e-2, atol=1e-3):
        """Runs ``torch.autograd.gradcheck`` on this Function with random inputs.

        Args:
            func: the ADFunctionCall or ADFunctionInterface to check.
            terminals: the terminals, one per sampled input.
            h: finite-difference step passed to gradcheck as ``eps``.
            rtol, atol: relative / absolute tolerances passed to gradcheck.

        Raises:
            Whatever gradcheck raises on a mismatch (``raise_exception=True``).
        """
        default_size = 10
        # Sample some random parameters of the appropriate length.
        sizes = []
        for i in range(len(terminals)):
            N = default_size
            if isinstance(func, ad.ADFunctionInterface):
                try:
                    N = func.n_in(i)
                    if N < 0:
                        N = default_size
                except NotImplementedError:
                    N = default_size
            sizes.append(N)
        # Inputs are created explicitly in double precision: the global
        # default dtype is only switched to float64 once forward() has run,
        # and single-precision inputs are too coarse for eps around 1e-6.
        params = [torch.randn(N, dtype=torch.float64, requires_grad=True)
                  for N in sizes]
        torch.autograd.gradcheck(ADModule.apply, tuple([func, terminals] + params),
                                 eps=h, atol=atol, rtol=rtol, raise_exception=True)
def torch_to_ad(module, args):
    """Wraps a PyTorch function as a Klamp't autodiff function and applies it.

    Args:
        module: the PyTorch callable to wrap in a TorchModuleFunction.
        args: list of scalars or numpy arrays to which the wrapped function
            is applied.

    Returns:
        The result of calling the TorchModuleFunction wrapper on ``*args``
        (a Klamp't autodiff function call on those arguments).
    """
    return TorchModuleFunction(module)(*args)
def ad_to_torch(func, terminals=None):
    """Converts a Klamp't autodiff function call or function instance to a
    PyTorch Function.

    Args:
        func: an ``ad.ADFunctionCall``, or a function instance providing
            ``n_args()`` and ``argname(i)``.
        terminals: optional list of the arguments that PyTorch will expect
            (``ad.ADTerminal`` objects or names).  If omitted, they are taken
            from ``func.terminals()`` for a function call, or from the
            function's argument names otherwise.

    Returns:
        ``ADModule(func, terminals)``.

    Raises:
        ValueError: if the provided terminals do not match ``func`` in number,
            or (for a function call) name a terminal it does not have.
    """
    if terminals is None:
        if isinstance(func, ad.ADFunctionCall):
            terminals = func.terminals()
        else:
            terminals = [func.argname(i) for i in range(func.n_args())]
    elif isinstance(func, ad.ADFunctionCall):
        fterminals = func.terminals()
        if len(terminals) != len(fterminals):
            raise ValueError("The number of terminals provided is incorrect")
        for t in terminals:
            name = t.name if isinstance(t, ad.ADTerminal) else t
            if name not in fterminals:
                # Report the terminals the function call actually has, not
                # the list the caller passed in.
                raise ValueError("Invalid terminal %s, function call %s only has terminals %s"%(name,str(func),str(fterminals)))
    else:
        try:
            n_args = func.n_args()
        except NotImplementedError:
            # The function does not declare an argument count; skip the check.
            pass
        else:
            if len(terminals) != n_args:
                raise ValueError("Invalid number of terminals, function %s expects %d"%(str(func),n_args))
    # NOTE(review): check_derivatives_torch invokes this Function through
    # ADModule.apply(func, terminals, *tensors); confirm that callers of the
    # object returned here can actually call it under the torch version in use.
    return ADModule(func, terminals)