2019-03-31 15:46:52 +02:00
# include <assert.h>
# include <cmath>
# include <ostream>
# include <Eigen/Dense>
# include "AutomaticDifferentiation_math.hpp"
// Convenience defines to make the code easier to read.
// `__DualBase_t` refers to the names `Scalar` and `N`, which `__tplDualBase_t` declares;
// the free functions further down use the two macros together as `__tplDualBase_t __DualBase_t f(...)`.
/// Type of the DualBase object: expands to `__Dual_DualBase<Scalar, N>`.
# define __DualBase_t __Dual_DualBase<Scalar, N>
/// Template declaration of DualBase object: expands to `template<typename Scalar, int N>`.
# define __tplDualBase_t template<typename Scalar, int N>
/// Implementation of dual numbers for automatic differentiation.
///
2019-03-31 19:02:52 +02:00
/// Description
/// -----------
/// Dual numbers serve to compute arbitrary function gradients efficiently and easily, with machine precision.
/// They can be used as a drop-in replacement for any of the base arithmetic types (double, float, etc).
/// The dual numbers are used in the so called forward automatic differentiation.
2019-03-31 15:46:52 +02:00
///
2019-03-31 19:02:52 +02:00
/// Here are some of the advantages of forward automatic differentiation compared to the backward method :
///
/// - Contrary to the backward method, no graph must be computed, and the memory footprint of the forward method is greatly reduced.
/// - Contrary to popular belief, there *IS* a way to compute the whole gradient of a function using only a *single* function call (see examples below).
/// - The foward method is the *only one* that can be used to compute gradients of complicated numerical functions, such as the result of a numerical integration.
/// The backward method in these cases explodes the memory limit and crawls to a stop as it tries to record *all* the operations involved in the evaluation of the function.
///
/// Template parameters and sub-classes
/// -----------------------------------
/// There are 3 things to consider when using the Dual class to compute gradients :
///
/// - What arithmetic type is going to be used.
/// - The number of variables with respect to which the gradient will be computed.
/// - Whether the vectors should be allocated dynamically on the heap or statically on the stack.
///
/// Since the class defined as a template, any arithmetic type will do. Typically (but not limited to) :
///
/// - float
/// - double
/// - long double
/// - quad float (quadmath)
/// - arbitrary precision (boost, gmp, mpfr, ...)
/// - fixed precision
/// - ...
///
/// The number of variables with respect to which the gradient will be computed depends entirely on the problem to be solved.
///
/// Finally, the type of memory management depends on how many variables will be part of the gradient computation : since the vector type used is provided by Eigen,
/// their recommendation should be followed :
///
/// - Static for N < ~15 -> DualS<>
/// - Dynamic for N > ~15 -> DualD<>
///
/// Typical use cases
/// -----------------
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~{.cpp}
/// Scalar xf = ..., yf = ..., zf = ...;
/// Dual<Scalar, 3> x(xf), y(yf), z(zf), fx;
/// x.diff(0); y.diff(1); z.diff(2);
/// fx = f(x, y, z); // <--- A single function call !
/// Scalar dfdx = fx.d(0),
/// dfdy = fx.d(1),
/// dfdz = fx.d(2);
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Reference for the underlaying mathematics : http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.89.7749&rep=rep1&type=pdf
2019-03-31 15:46:52 +02:00
template < typename Scalar , int N >
struct __Dual_DualBase
{
static_assert ( N > 0 , " N must be > 0. " ) ;
using VectorT = __Dual_VectorT ;
/// Sets all the elements of the b vector to 0.
void SetBToZero ( )
{
# if __Dual_bdynamic == 1
b = VectorT : : Zero ( N ) ;
# else
b = VectorT : : Zero ( ) ;
# endif
}
2019-03-31 19:02:52 +02:00
__Dual_DualBase ( const Scalar & _a = Scalar ( ) )
: a ( _a )
{ SetBToZero ( ) ; }
__Dual_DualBase ( const Scalar & _a , const VectorT & _b )
: a ( _a )
{
assert ( _b . size ( ) = = N ) ;
b = _b ;
}
/// Use this function to set what variable is to be derived.
///
/// The two following statements are *not exactly* equivalent, but produce the same effect (the two last cases are equivalent) :
///
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~{.cpp}
/// // Using Dual::D(int i)
/// Dual<Scalar, 3> x(xf), y(yf), z(zf), fx;
/// fx = f(x+Dual::D(0), y+Dual::D(1), z+Dual::D(2));
/// Scalar dfdx = fx.d(0);
/// Scalar dfdy = fx.d(1);
/// Scalar dfdz = fx.d(2);
///
/// // Using diff(int i) before the function call
/// Dual<Scalar, 3> x(xf), y(yf), z(zf), fx;
/// x.diff(0); y.diff(1); z.diff(2);
/// fx = f(x, y, z);
/// Scalar dfdx = fx.d(0);
/// Scalar dfdy = fx.d(1);
/// Scalar dfdz = fx.d(2);
///
/// // Using diff(int i) directly during the function call
/// Dual<Scalar, 3> x(xf), y(yf), z(zf), fx;
/// fx = f(x.diff(0), y.diff(1), z.diff(2));
/// Scalar dfdx = fx.d(0);
/// Scalar dfdy = fx.d(1);
/// Scalar dfdz = fx.d(2);
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
///
static __Dual_DualBase D ( int i = 0 )
{
assert ( i > = 0 ) ;
assert ( i < N ) ;
__Dual_DualBase res ( Scalar ( 0 ) ) ;
res . b [ i ] = Scalar ( 1 ) ;
return res ;
}
/// Use this function to set what variable is to be derived. Only one derivative can be toggled at once using this function.
/// For example, If x.b = {1 1 0}, after transformation y = f(x), y.b = {dy/dx, dy/dx, 0}
///
/// Only one derivative should be selected per variable.
///
/// In order to compute the gradient of a function, the following code can be used :
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~{.cpp}
/// Dual<S> x(xd), y(yd), z(zd), fxyz;
/// x.diff(0,3); // Set the first derivative to be that of x.
/// y.diff(1,3); // Set the first derivative to be that of y.
/// z.diff(2,3); // Set the first derivative to be that of z.
/// fxyz = f(x, y, z); // Evaluate the function to differentiate.
/// S dfdx = fxyz.d(0); // df/dx
/// S dfdy = fxyz.d(1); // df/dy
/// S dfdz = fxyz.d(2); // df/dz
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
__Dual_DualBase const & diff ( int i = 0 )
{
assert ( i > = 0 ) ;
assert ( i < N ) ;
SetBToZero ( ) ;
b [ i ] = Scalar ( 1 ) ;
return * this ;
}
/// Returns a reference to the value
Scalar const & x ( ) const { return a ; }
Scalar & x ( ) { return a ; }
/// Returns a reference to the vector of infinitesimal parts
VectorT const & B ( ) const { return b ; }
VectorT & B ( ) { return b ; }
/// Returns the derivative value at index i. An assertion protects against i >= N.
Scalar const & d ( int i ) const
{
assert ( i > = 0 ) ;
assert ( i < N ) ;
return b [ i ] ;
}
/// Returns the derivative value at index i. An assertion protects against i >= N.
Scalar & d ( int i )
{
assert ( i > = 0 ) ;
assert ( i < N ) ;
return b [ i ] ;
}
__Dual_DualBase & operator + = ( const __Dual_DualBase & x )
{
a + = x . a ;
b + = x . b ;
return * this ;
}
__Dual_DualBase & operator - = ( const __Dual_DualBase & x )
{
a - = x . a ;
b - = x . b ;
return * this ;
}
__Dual_DualBase & operator * = ( const __Dual_DualBase & x )
{
b = a * x . b + b * x . a ;
a * = x . a ;
return * this ;
}
__Dual_DualBase & operator / = ( const __Dual_DualBase & x )
{
b = ( x . a * b - a * x . b ) / ( x . a * x . a ) ;
a / = x . a ;
return * this ;
}
__Dual_DualBase & operator + + ( ) { return ( ( * this ) + = Scalar ( 1. ) ) ; } // ++x
__Dual_DualBase & operator - - ( ) { return ( ( * this ) - = Scalar ( 1. ) ) ; } // --x
__Dual_DualBase operator + + ( int ) { // x++
__Dual_DualBase copy = * this ;
( * this ) + = Scalar ( 1. ) ;
return copy ;
}
__Dual_DualBase operator - - ( int ) { // x--
__Dual_DualBase copy = * this ;
( * this ) - = Scalar ( 1. ) ;
return copy ;
}
__Dual_DualBase operator + ( const __Dual_DualBase & x ) const {
__Dual_DualBase res ( * this ) ;
return ( res + = x ) ;
}
__Dual_DualBase operator - ( const __Dual_DualBase & x ) const {
__Dual_DualBase res ( * this ) ;
return ( res - = x ) ;
}
__Dual_DualBase operator * ( const __Dual_DualBase & x ) const
{
__Dual_DualBase res ( * this ) ;
return ( res * = x ) ;
}
__Dual_DualBase operator / ( const __Dual_DualBase & x ) const
{
__Dual_DualBase res ( * this ) ;
return ( res / = x ) ;
}
__Dual_DualBase operator + ( void ) const { return ( * this ) ; } // +x
__Dual_DualBase operator - ( void ) const { return __Dual_DualBase ( - a , - b ) ; } // -x
bool operator = = ( const __Dual_DualBase & x ) const { return ( a = = x . a ) ; }
bool operator ! = ( const __Dual_DualBase & x ) const { return ( a ! = x . a ) ; }
bool operator < ( const __Dual_DualBase & x ) const { return ( a < x . a ) ; }
bool operator < = ( const __Dual_DualBase & x ) const { return ( a < = x . a ) ; }
bool operator > ( const __Dual_DualBase & x ) const { return ( a > x . a ) ; }
bool operator > = ( const __Dual_DualBase & x ) const { return ( a > = x . a ) ; }
/// Explicit conversion of the dual number to *ANY* type. Clearely, not every conversion makes sense. Use at your own risk.
template < typename T > explicit operator T ( ) const { return static_cast < T > ( a ) ; }
Scalar a ; ///< Real part
VectorT b ; ///< Infinitesimal parts
2019-03-31 15:46:52 +02:00
} ;
/// value + dual: promotes the left-hand value v to a dual number (all infinitesimal parts 0), then adds x to it.
template <typename A, typename B, int N>
__Dual_DualBase<B, N> operator+(A const & v, __Dual_DualBase<B, N> const & x) {
    __Dual_DualBase<B, N> lhs(v);
    lhs += x;
    return lhs;
}
/// value - dual: promotes the left-hand value v to a dual number (all infinitesimal parts 0), then subtracts x from it.
template <typename A, typename B, int N>
__Dual_DualBase<B, N> operator-(A const & v, __Dual_DualBase<B, N> const & x) {
    __Dual_DualBase<B, N> lhs(v);
    lhs -= x;
    return lhs;
}
/// value * dual: promotes the left-hand value v to a dual number (all infinitesimal parts 0), then multiplies it by x.
template <typename A, typename B, int N>
__Dual_DualBase<B, N> operator*(A const & v, __Dual_DualBase<B, N> const & x) {
    __Dual_DualBase<B, N> lhs(v);
    lhs *= x;
    return lhs;
}
/// value / dual: promotes the left-hand value v to a dual number (all infinitesimal parts 0), then divides it by x.
template <typename A, typename B, int N>
__Dual_DualBase<B, N> operator/(A const & v, __Dual_DualBase<B, N> const & x) {
    __Dual_DualBase<B, N> lhs(v);
    lhs /= x;
    return lhs;
}
// Basic mathematical functions for __Dual_DualBase numbers
// f(a + b*d) = f(a) + b*f'(a)*d
// Trigonometric functions
/// Cosine of a dual number: real part cos(a), infinitesimal parts -b * sin(a).
__tplDualBase_t __DualBase_t cos(const __DualBase_t & x) {
    const Scalar sin_a = sin(x.a);
    return __DualBase_t(cos(x.a), -x.b * sin_a);
}
/// Sine of a dual number: real part sin(a), infinitesimal parts b * cos(a).
__tplDualBase_t __DualBase_t sin(const __DualBase_t & x) {
    const Scalar cos_a = cos(x.a);
    return __DualBase_t(sin(x.a), x.b * cos_a);
}
/// Tangent of a dual number: real part tan(a), infinitesimal parts b * sec(a) * sec(a).
__tplDualBase_t __DualBase_t tan(const __DualBase_t & x) {
    const Scalar sec_a = sec(x.a);
    return __DualBase_t(tan(x.a), x.b * sec_a * sec_a);
}
/// Secant of a dual number: real part sec(a), infinitesimal parts b * sec(a) * tan(a).
__tplDualBase_t __DualBase_t sec(const __DualBase_t & x) {
    const Scalar sec_a = sec(x.a);
    return __DualBase_t(sec_a, x.b * sec_a * tan(x.a));
}
/// Cotangent of a dual number: real part cot(a), infinitesimal parts b * (-csc(a) * csc(a)).
__tplDualBase_t __DualBase_t cot(const __DualBase_t & x) {
    const Scalar csc_a = csc(x.a);
    const Scalar slope = -csc_a * csc_a;
    return __DualBase_t(cot(x.a), x.b * slope);
}
/// Cosecant of a dual number: real part csc(a), infinitesimal parts b * (-cot(a) * csc(a)).
__tplDualBase_t __DualBase_t csc(const __DualBase_t & x) {
    const Scalar csc_a = csc(x.a);
    const Scalar slope = -cot(x.a) * csc_a;
    return __DualBase_t(csc_a, x.b * slope);
}
// Inverse trigonometric functions
/// Arc cosine of a dual number: real part acos(a), infinitesimal parts b * (-1 / sqrt(1 - a^2)).
__tplDualBase_t __DualBase_t acos(const __DualBase_t & x) {
    const Scalar slope = -Scalar(1.) / sqrt(Scalar(1.) - x.a * x.a);
    return __DualBase_t(acos(x.a), x.b * slope);
}
/// Arc sine of a dual number: real part asin(a), infinitesimal parts b * (1 / sqrt(1 - a^2)).
__tplDualBase_t __DualBase_t asin(const __DualBase_t & x) {
    const Scalar slope = Scalar(1.) / sqrt(Scalar(1.) - x.a * x.a);
    return __DualBase_t(asin(x.a), x.b * slope);
}
/// Arc tangent of a dual number: real part atan(a), infinitesimal parts b * (1 / (a^2 + 1)).
__tplDualBase_t __DualBase_t atan(const __DualBase_t & x) {
    const Scalar slope = Scalar(1.) / (x.a * x.a + Scalar(1.));
    return __DualBase_t(atan(x.a), x.b * slope);
}
/// Arc secant of a dual number: real part asec(a), infinitesimal parts b * (1 / (sqrt(1 - 1/a^2) * a^2)).
__tplDualBase_t __DualBase_t asec(const __DualBase_t & x) {
    const Scalar a_squared = x.a * x.a;
    const Scalar slope = Scalar(1.) / (sqrt(Scalar(1.) - Scalar(1.) / a_squared) * a_squared);
    return __DualBase_t(asec(x.a), x.b * slope);
}
/// Arc cotangent of a dual number: real part acot(a), infinitesimal parts b * (-1 / (a^2 + 1)).
__tplDualBase_t __DualBase_t acot(const __DualBase_t & x) {
    const Scalar slope = -Scalar(1.) / ((x.a * x.a) + Scalar(1.));
    return __DualBase_t(acot(x.a), x.b * slope);
}
/// Arc cosecant of a dual number: real part acsc(a), infinitesimal parts b * (-1 / (sqrt(1 - 1/a^2) * a^2)).
__tplDualBase_t __DualBase_t acsc(const __DualBase_t & x) {
    const Scalar a_squared = x.a * x.a;
    const Scalar slope = -Scalar(1.) / (sqrt(Scalar(1.) - Scalar(1.) / a_squared) * a_squared);
    return __DualBase_t(acsc(x.a), x.b * slope);
}
// Hyperbolic trigonometric functions
/// Hyperbolic cosine of a dual number: real part cosh(a), infinitesimal parts b * sinh(a).
__tplDualBase_t __DualBase_t cosh(const __DualBase_t & x) {
    const Scalar sinh_a = sinh(x.a);
    return __DualBase_t(cosh(x.a), x.b * sinh_a);
}
/// Hyperbolic sine of a dual number: real part sinh(a), infinitesimal parts b * cosh(a).
__tplDualBase_t __DualBase_t sinh(const __DualBase_t & x) {
    const Scalar cosh_a = cosh(x.a);
    return __DualBase_t(sinh(x.a), x.b * cosh_a);
}
/// Hyperbolic tangent of a dual number: real part tanh(a), infinitesimal parts b * sech(a) * sech(a).
__tplDualBase_t __DualBase_t tanh(const __DualBase_t & x) {
    const Scalar sech_a = sech(x.a);
    return __DualBase_t(tanh(x.a), x.b * sech_a * sech_a);
}
/// Hyperbolic secant of a dual number: real part sech(a), infinitesimal parts b * (-sech(a) * tanh(a)).
__tplDualBase_t __DualBase_t sech(const __DualBase_t & x) {
    const Scalar sech_a = sech(x.a);
    const Scalar slope = -sech_a * tanh(x.a);
    return __DualBase_t(sech_a, x.b * slope);
}
/// Hyperbolic cotangent of a dual number: real part coth(a), infinitesimal parts b * (-csch(a) * csch(a)).
__tplDualBase_t __DualBase_t coth(const __DualBase_t & x) {
    const Scalar csch_a = csch(x.a);
    const Scalar slope = -csch_a * csch_a;
    return __DualBase_t(coth(x.a), x.b * slope);
}
/// Hyperbolic cosecant of a dual number: real part csch(a), infinitesimal parts b * (-coth(a) * csch(a)).
__tplDualBase_t __DualBase_t csch(const __DualBase_t & x) {
    const Scalar csch_a = csch(x.a);
    const Scalar slope = -coth(x.a) * csch_a;
    return __DualBase_t(csch_a, x.b * slope);
}
// Inverse hyperbolic trigonometric functions
/// Inverse hyperbolic cosine of a dual number: real part acosh(a), infinitesimal parts b * (1 / sqrt(a^2 - 1)).
__tplDualBase_t __DualBase_t acosh(const __DualBase_t & x) {
    const Scalar slope = Scalar(1.) / sqrt((x.a * x.a) - Scalar(1.));
    return __DualBase_t(acosh(x.a), x.b * slope);
}
/// Inverse hyperbolic sine of a dual number: real part asinh(a), infinitesimal parts b * (1 / sqrt(a^2 + 1)).
__tplDualBase_t __DualBase_t asinh(const __DualBase_t & x) {
    const Scalar slope = Scalar(1.) / sqrt((x.a * x.a) + Scalar(1.));
    return __DualBase_t(asinh(x.a), x.b * slope);
}
/// Inverse hyperbolic tangent of a dual number: real part atanh(a), infinitesimal parts b * (1 / (1 - a^2)).
__tplDualBase_t __DualBase_t atanh(const __DualBase_t & x) {
    const Scalar slope = Scalar(1.) / (Scalar(1.) - (x.a * x.a));
    return __DualBase_t(atanh(x.a), x.b * slope);
}
/// Inverse hyperbolic secant of a dual number: real part asech(a), infinitesimal parts b * (-1 / (sqrt(1/a^2 - 1) * a^2)).
__tplDualBase_t __DualBase_t asech(const __DualBase_t & x) {
    const Scalar a_squared = x.a * x.a;
    const Scalar slope = Scalar(-1.) / (sqrt(Scalar(1.) / a_squared - Scalar(1.)) * a_squared);
    return __DualBase_t(asech(x.a), x.b * slope);
}
/// Inverse hyperbolic cotangent of a dual number: real part acoth(a), infinitesimal parts b * (-1 / (a^2 - 1)).
__tplDualBase_t __DualBase_t acoth(const __DualBase_t & x) {
    const Scalar slope = -Scalar(1.) / ((x.a * x.a) - Scalar(1.));
    return __DualBase_t(acoth(x.a), x.b * slope);
}
/// Inverse hyperbolic cosecant of a dual number: real part acsch(a), infinitesimal parts b * (-1 / (sqrt(1/a^2 + 1) * a^2)).
__tplDualBase_t __DualBase_t acsch(const __DualBase_t & x) {
    const Scalar a_squared = x.a * x.a;
    const Scalar slope = -Scalar(1.) / (sqrt(Scalar(1.) / a_squared + Scalar(1.)) * a_squared);
    return __DualBase_t(acsch(x.a), x.b * slope);
}
// Exponential functions
/// Exponential of a dual number: real part exp(a), infinitesimal parts b * exp(a).
__tplDualBase_t __DualBase_t exp(const __DualBase_t & x) {
    const Scalar exp_a = exp(x.a);
    return __DualBase_t(exp_a, x.b * exp_a);
}
/// Natural logarithm of a dual number: real part log(a), infinitesimal parts b / a.
__tplDualBase_t __DualBase_t log(const __DualBase_t & x) {
    const Scalar log_a = log(x.a);
    return __DualBase_t(log_a, x.b / x.a);
}
/// Base-10 exponential of a dual number: real part exp10(a), infinitesimal parts b * (log(10) * exp10(a)).
__tplDualBase_t __DualBase_t exp10(const __DualBase_t & x) {
    const Scalar exp10_a = exp10(x.a);
    const Scalar slope = log(Scalar(10.)) * exp10_a;
    return __DualBase_t(exp10_a, x.b * slope);
}
/// Base-10 logarithm of a dual number: real part log10(a), infinitesimal parts b / (log(10) * a).
__tplDualBase_t __DualBase_t log10(const __DualBase_t & x) {
    const Scalar divisor = log(Scalar(10.)) * x.a;
    return __DualBase_t(log10(x.a), x.b / divisor);
}
/// Base-2 exponential of a dual number: real part exp2(a), infinitesimal parts b * (log(2) * exp2(a)).
__tplDualBase_t __DualBase_t exp2(const __DualBase_t & x) {
    const Scalar exp2_a = exp2(x.a);
    const Scalar slope = log(Scalar(2.)) * exp2_a;
    return __DualBase_t(exp2_a, x.b * slope);
}
/// Base-2 logarithm of a dual number: real part log2(a), infinitesimal parts b / (log(2) * a).
__tplDualBase_t __DualBase_t log2(const __DualBase_t & x) {
    const Scalar divisor = log(Scalar(2.)) * x.a;
    return __DualBase_t(log2(x.a), x.b / divisor);
}
/// Dual number raised to a dual power, evaluated as exp(n * log(x)).
/// Because it goes through log(x), the result depends on log(x.a) and the division by x.a being well defined.
__tplDualBase_t __DualBase_t pow(const __DualBase_t & x, const __DualBase_t & n) {
    const __DualBase_t log_x = log(x);
    return exp(n * log_x);
}
/// Dual number raised to a constant (non-dual) power n, using the power rule:
/// real part pow(a, n), infinitesimal parts b * (n * pow(a, n - 1)).
///
/// The exponent is first converted to Scalar. The previous exp(n * log(x)) formulation returned NaN
/// for every negative base (log of a negative number), even for integer exponents such as pow(x, 2).
template <typename Scalar, typename Scalar2, int N>
__DualBase_t pow(const __DualBase_t & x, const Scalar2 & n) {
    const Scalar exponent = static_cast<Scalar>(n);
    // x^0 is the constant 1: return it directly so that a base of 0 does not produce 0 * pow(0, -1) = NaN.
    if (exponent == Scalar(0)) {
        return __DualBase_t(Scalar(1));
    }
    const Scalar slope = exponent * pow(x.a, exponent - Scalar(1));
    return __DualBase_t(pow(x.a, exponent), x.b * slope);
}
/// Constant (non-dual) base x raised to a dual power n, evaluated as exp(n * log(x)) with x first converted to Scalar.
template <typename Scalar, typename Scalar2, int N>
__DualBase_t pow(const Scalar2 & x, const __DualBase_t & n) {
    const Scalar log_base = log(static_cast<Scalar>(x));
    return exp(n * log_base);
}
// Other functions
/// Square root of a dual number: real part sqrt(a), infinitesimal parts b / (2 * sqrt(a)).
__tplDualBase_t __DualBase_t sqrt(const __DualBase_t & x) {
    const Scalar root = sqrt(x.a);
    return __DualBase_t(root, x.b / (Scalar(2.) * root));
}
/// Cube root of a dual number: real part cbrt(a), infinitesimal parts b / (3 * cbrt(a)^2).
///
/// The denominator is built from cbrt(a) squared rather than pow(a, 2/3): pow with a negative base
/// and a fractional exponent yields NaN, whereas the cube root (and its derivative) is defined for negative a.
__tplDualBase_t __DualBase_t cbrt(const __DualBase_t & x) {
    const Scalar root = cbrt(x.a);
    return __DualBase_t(root, x.b / (Scalar(3.) * root * root));
}
/// Sign of a dual number: real part sign(a); every infinitesimal part of the result is 0.
__tplDualBase_t __DualBase_t sign(const __DualBase_t & x) {
    const Scalar sign_a = sign(x.a);
    return __DualBase_t(sign_a);
}
/// Absolute value of a dual number: real part abs(a), infinitesimal parts b * sign(a).
__tplDualBase_t __DualBase_t abs(const __DualBase_t & x) {
    const Scalar sign_a = sign(x.a);
    return __DualBase_t(abs(x.a), x.b * sign_a);
}
/// Absolute value (fabs flavour) of a dual number: real part fabs(a), infinitesimal parts b * sign(a).
__tplDualBase_t __DualBase_t fabs(const __DualBase_t & x) {
    const Scalar sign_a = sign(x.a);
    return __DualBase_t(fabs(x.a), x.b * sign_a);
}
/// Heaviside step of a dual number: real part heaviside(a); every infinitesimal part of the result is 0.
__tplDualBase_t __DualBase_t heaviside(const __DualBase_t & x) {
    const Scalar step = heaviside(x.a);
    return __DualBase_t(step);
}
/// Floor of a dual number: real part floor(a); every infinitesimal part of the result is 0.
__tplDualBase_t __DualBase_t floor(const __DualBase_t & x) {
    const Scalar floored = floor(x.a);
    return __DualBase_t(floored);
}
/// Ceiling of a dual number: real part ceil(a); every infinitesimal part of the result is 0.
__tplDualBase_t __DualBase_t ceil(const __DualBase_t & x) {
    const Scalar ceiled = ceil(x.a);
    return __DualBase_t(ceiled);
}
/// Rounding of a dual number: real part round(a); every infinitesimal part of the result is 0.
__tplDualBase_t __DualBase_t round(const __DualBase_t & x) {
    const Scalar rounded = round(x.a);
    return __DualBase_t(rounded);
}
/// Streams a dual number: only the real part is written, the infinitesimal parts are not printed.
__tplDualBase_t std::ostream & operator<<(std::ostream & s, const __DualBase_t & x)
{
    s << x.a;
    return s;
}
// -----------------------------------------------------------------------------------------------------------
/*
/// Macro to create a function object that returns the gradient of the function at X.
/// Designed to work with functions, lambdas, etc.
# define CREATE_GRAD_FUNCTION_OBJECT(Func, GradFuncName) \
struct GradFuncName { \
template < typename Scalar > \
Scalar operator ( ) ( Scalar const & x ) { \
__Dual_DualBase < Scalar > X ( x ) ; \
X . diff ( 0 , 1 ) ; \
__Dual_DualBase < Scalar > Y = Func < __Dual_DualBase < Scalar > > ( X ) ; \
return Y . d ( 0 ) ; \
} \
template < typename Scalar > \
void get_f_grad ( Scalar const & x , Scalar & fx , Scalar & gradfx ) { \
__Dual_DualBase < Scalar > X ( x ) ; \
X . diff ( 0 , 1 ) ; \
__Dual_DualBase < Scalar > Y = Func < __Dual_DualBase < Scalar > > ( X ) ; \
fx = Y . x ( ) ; \
gradfx = Y . d ( 0 ) ; \
} \
}
//*/
/*
/// Macro to create a function object that returns the gradient of the function at X.
/// Designed to work with function objects.
# define CREATE_GRAD_FUNCTION_OBJECT_FUNCTOR(Func, GradFuncName) \
struct GradFuncName { \
template < typename Scalar > \
Scalar operator ( ) ( Scalar const & x ) { \
__Dual_DualBase < Scalar > X ( x ) ; \
X . diff ( 0 , 1 ) ; \
Func f ; \
__Dual_DualBase < Scalar > Y = f ( X ) ; \
return Y . d ( 0 ) ; \
} \
template < typename Scalar > \
void get_f_grad ( Scalar const & x , Scalar & fx , Scalar & gradfx ) { \
__Dual_DualBase < Scalar > X ( x ) ; \
X . diff ( 0 , 1 ) ; \
Func f ; \
__Dual_DualBase < Scalar > Y = f ( X ) ; \
fx = Y . x ( ) ; \
gradfx = Y . d ( 0 ) ; \
} \
}
//*/