AutomaticDifferentiation/test_AutomaticDifferentiation_manual.cpp

504 lines
14 KiB
C++
Raw Normal View History

2019-03-24 19:13:06 +01:00
#include "AutomaticDifferentiation.hpp"
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <vector>
// Names pulled into the global namespace for the tests in this file.
// NOTE(review): std::cout is declared twice; the second declaration is
// redundant (another name may have been intended — confirm).
using std::cout;
using std::cout;
using std::setw;
// Print `<expression text>\t= <value>` followed by a flushing newline.
// std::setprecision is sticky: std::cout keeps a precision of 16 afterwards.
#define PRINT_VAR(x) std::cout << #x << "\t= " << std::setprecision(16) << (x) << std::endl

// Print the `.a` and `.b` members of a dual number, each in a 10-character
// column with 4 decimals. std::fixed and the precision are sticky as well,
// so they also affect whatever is printed to std::cout later.
#define PRINT_DUAL(x) std::cout << #x << "\t= " << std::fixed << std::setprecision(4) << std::setw(10) << (x).a << ", " << std::setw(10) << (x).b << std::endl

// Absolute tolerance shared by the comparison macros below.
#define TEST_EQ_TOL 1e-9

// Assert that two floating-point values agree to within TEST_EQ_TOL.
// std::abs is qualified so that a floating-point overload is always selected;
// an unqualified `abs` can resolve to the C `int abs(int)` and silently
// truncate the difference to 0. Compiled out when NDEBUG is defined.
#define TEST_EQ_DOUBLE(a, b) assert(std::abs((a) - (b)) <= TEST_EQ_TOL)

// Check that `fct` applied to Dual<double>(x) has the same `.a` member as
// `fct` applied to the plain value x, to within TEST_EQ_TOL. On mismatch a
// diagnostic is written to std::cerr and the program exits with status 1.
// Wrapped in do { } while (0) so that one invocation is exactly one statement
// and is safe inside an unbraced if/else.
// Note: when the difference is NaN the `>` comparison is false, so the check
// passes; callers must pick arguments inside the function's domain.
#define TEST_FUNCTION_ON_DUAL_DOUBLE(fct, x) \
    do { \
        if (std::abs((fct(Dual<double>((x))).a) - (fct((x)))) > TEST_EQ_TOL) { \
            std::cerr << "Assertion failed at " << __FILE__ << ":" << __LINE__ << " : " << #fct << "<Dual<double>>(" << (x) << ") != " << #fct << "(" << (x) << ")" << "\n"; \
            std::cerr << "Got " << (fct(Dual<double>((x))).a) << " ; expected " << (fct((x))) << "\n"; \
            std::exit(1); \
        } \
    } while (0)
// Debug helper: writes the compiler-generated signature of this
// instantiation (which spells out the type T) to std::cout.
// Relies on the __PRETTY_FUNCTION__ compiler extension.
template<typename T>
void print_T()
{
    const char * const signature = __PRETTY_FUNCTION__;
    std::cout << signature << '\n';
}
// Cubic test polynomial f1(x) = 5x^3 + 3x^2 - 2x + 4.
// Generic in Scalar so it can be evaluated on plain numbers and on any type
// that is constructible from a double and supports +, - and *.
template<typename Scalar>
Scalar f1(const Scalar & x)
{
    const Scalar cubicTerm     = Scalar(5.)*x*x*x;
    const Scalar quadraticTerm = Scalar(3.)*x*x;
    const Scalar linearTerm    = Scalar(2.)*x;
    const Scalar constantTerm  = Scalar(4.);
    return cubicTerm + quadraticTerm - linearTerm + constantTerm;
}
// Analytic first derivative of f1: df1(x) = 15x^2 + 6x - 2.
template<typename Scalar>
Scalar df1(const Scalar & x)
{
    const Scalar quadraticTerm = Scalar(15.)*x*x;
    const Scalar linearTerm    = Scalar(6.)*x;
    const Scalar constantTerm  = Scalar(2.);
    return quadraticTerm + linearTerm - constantTerm;
}
// Analytic second derivative of f1: ddf1(x) = 30x + 6.
template<typename Scalar>
Scalar ddf1(const Scalar & x)
{
    const Scalar slope  = Scalar(30.)*x;
    const Scalar offset = Scalar(6.);
    return slope + offset;
}
// Evaluates f1 on Dual<Scalar>(x) + Dual<Scalar>::d() and returns the `.b`
// member of the result.
// NOTE(review): d() is presumably the unit derivative seed, which would make
// this the first derivative of f1 at x — confirm in AutomaticDifferentiation.hpp.
template<typename Scalar>
Scalar g1(Scalar x) {
    const auto seeded = Dual<Scalar>(x) + Dual<Scalar>::d();
    return f1(seeded).b;
}
// Applies g1 to Dual<Scalar>(x) + Dual<Scalar>::d() and returns the `.b`
// member of the result, i.e. the same seeding step as g1 nested one level
// deeper.
// NOTE(review): presumably yields the second derivative of f1 at x — confirm.
template<typename Scalar>
Scalar h1(Scalar x) {
    const auto seeded = Dual<Scalar>(x) + Dual<Scalar>::d();
    return g1(seeded).b;
}
// Product test function f2(x) = (x + 2)(x + 1).
// Takes and returns its argument type by value; D must be constructible from
// a double and support + and *.
template<class D>
D f2(D x)
{
    const D shiftedByTwo = x + D(2.0);
    const D shiftedByOne = x + D(1.0);
    return shiftedByTwo * shiftedByOne;
}
// Euclidean norm of the vector (x, y, z): sqrt(z^2 + y^2 + x^2).
// `sqrt` is called unqualified so that an overload for Scalar can be picked up.
template<typename Scalar>
Scalar f3(const Scalar & x, const Scalar & y, const Scalar & z)
{
    const Scalar squaredNorm = z*z + y*y + x*x;
    return sqrt(squaredNorm);
}
// Analytic partial derivative of f3 with respect to x: x / sqrt(z^2 + y^2 + x^2).
// No guard against a zero norm: the division is performed as-is.
template<typename Scalar>
Scalar df3x(const Scalar & x, const Scalar & y, const Scalar & z)
{
    const Scalar squaredNorm = z*z + y*y + x*x;
    return x / sqrt(squaredNorm);
}
// Analytic partial derivative of f3 with respect to y: y / sqrt(z^2 + y^2 + x^2).
// No guard against a zero norm: the division is performed as-is.
template<typename Scalar>
Scalar df3y(const Scalar & x, const Scalar & y, const Scalar & z)
{
    const Scalar squaredNorm = z*z + y*y + x*x;
    return y / sqrt(squaredNorm);
}
// Analytic partial derivative of f3 with respect to z: z / sqrt(z^2 + y^2 + x^2).
// No guard against a zero norm: the division is performed as-is.
template<typename Scalar>
Scalar df3z(const Scalar & x, const Scalar & y, const Scalar & z)
{
    const Scalar squaredNorm = z*z + y*y + x*x;
    return z / sqrt(squaredNorm);
}
// Forward declarations of the test routines so that main() can be defined
// before them.
void test_basic();
void test_scalar_functions();
void test_derivative_all();
void test_derivative_pow();
void test_derivative_simple();
void test_derivative_simple_2();
void test_derivative_simple_3();
// NOTE(review): in the visible part of this file the only definition of
// test_derivative_nested() sits inside a commented-out region, so calling it
// would fail at link time unless it is defined elsewhere.
void test_derivative_nested();
// Entry point: runs the enabled test routines in a fixed order and returns 0.
// A failing check terminates the process from inside the test (assert / exit).
int main()
{
    using TestRoutine = void (*)();

    // Enabled tests, in execution order.
    const TestRoutine enabledTests[] = {
        &test_scalar_functions,
        &test_basic,
        &test_derivative_all,
        &test_derivative_pow,
        // Currently disabled:
        // &test_derivative_simple,
        // &test_derivative_simple_2,
        // &test_derivative_simple_3,
        // &test_derivative_nested,
    };

    for (const TestRoutine runTest : enabledTests)
        runTest();

    return 0;
}
// Basic checks of Dual<double>: construction, member access, the value part
// of the four arithmetic operators, and agreement between every supported
// elementary function evaluated on a Dual<double> and on a plain double.
// Also prints a few reference values for visual inspection.
void test_basic()
{
    cout << "\ntest_basic()\n";
    using D = Dual<double>;
    cout.precision(16);

    double x = 2;
    double y = 5;
    D X(x), Y(y), Z(x, y);
    D W = Z;

    // construction from a single value: .a holds the value, .b is zero
    assert(X.a == x);
    assert(X.b == 0.);
    assert(Y.a == y);
    assert(Y.b == 0.);

    // value part of the arithmetic operators
    assert((X+Y).a == (x+y));
    assert((X-Y).a == (x-y));
    assert((X*Y).a == (x*y));
    assert((X/Y).a == (x/y));
    // NOTE(review): `-X.a` parses as `-(X.a)`, so this negates a double and
    // does not exercise a unary minus on D; `(-X).a` may have been intended.
    assert(-X.a == -x);

    // two-argument construction and copy
    assert(D(1., 2.).a == 1.);
    assert(D(1., 2.).b == 2.);
    assert(W.a == Z.a);
    assert(W.b == Z.b);
    assert((D(1., 2.)+D(4., 7.)).a == 5.);
    assert((D(1., 2.)+D(4., 7.)).b == 9.);

    // test all the value returned by non-linear functions
    assert(heaviside(-1.) == 0.);
    assert(heaviside(0.) == 1.);
    assert(heaviside(1.) == 1.);
    assert(sign(-10.) == -1.);
    assert(sign(0.) == 0.);
    assert(sign(10.) == 1.);

    // printed reference values (not asserted)
    PRINT_VAR(abs(-10.));
    PRINT_VAR(abs(10.));
    PRINT_VAR(abs(Dual<double>(-10.)).a);
    PRINT_VAR(abs(Dual<double>(10.)).a);
    PRINT_VAR(abs(Dual<double>(-1.62)).a);
    PRINT_VAR(abs(-1.62));
    PRINT_VAR(exp10(-3.));
    PRINT_VAR(exp10(-2.));
    PRINT_VAR(exp10(-1.));
    PRINT_VAR(exp10(0.));
    PRINT_VAR(exp10(1.));
    PRINT_VAR(exp10(2.));
    PRINT_VAR(exp10(3.));
    x = -1.5;
    PRINT_VAR((x >= (0.)) ? pow((10.), x) : (1.)/pow((10.), -x));
    PRINT_VAR(pow(10., -3.));
    PRINT_VAR(pow(10., 3.));
    PRINT_VAR(1./pow(10., 3.));
    PRINT_VAR(exp2(-3.));
    PRINT_VAR(exp2(3.));
    TEST_EQ_DOUBLE(exp10(-3.), 1./exp10(3.));
    TEST_EQ_DOUBLE(exp10(-3.), 0.001);
    TEST_EQ_DOUBLE(exp10( 3.), 1000.);
    PRINT_VAR(atanh(0.62));
    PRINT_VAR(atanh(Dual<double>(0.62)).a);
    PRINT_VAR(acsc(Dual<double>(1.62)).a);
    PRINT_VAR(acsc(1.62));
    TEST_EQ_DOUBLE(pow(Dual<double>(1.62), Dual<double>(1.5)).a, pow(1.62, 1.5));

    // trigonometric functions
    TEST_FUNCTION_ON_DUAL_DOUBLE(cos, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(sin, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(tan, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(sec, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(cot, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(csc, 1.62);

    // inverse trigonometric functions
    TEST_FUNCTION_ON_DUAL_DOUBLE(acos, 0.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(asin, 0.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(atan, 0.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(asec, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(acot, 0.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(acsc, 1.62);

    // exponential functions
    TEST_FUNCTION_ON_DUAL_DOUBLE(exp, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(log, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(exp10, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(log10, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(exp2, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(log2, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(sqrt, 1.62);

    // hyperbolic trigonometric functions
    TEST_FUNCTION_ON_DUAL_DOUBLE(cosh, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(sinh, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(tanh, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(sech, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(coth, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(csch, 1.62);

    // inverse hyperbolic trigonometric functions
    TEST_FUNCTION_ON_DUAL_DOUBLE(acosh, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(asinh, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(atanh, 0.62);
    // The real-valued inverse hyperbolic secant is only defined on (0, 1].
    // With an argument of 1.62 both evaluations would typically be NaN, and a
    // NaN difference never trips the tolerance check, so the test could not
    // fail. 0.62 is the in-domain argument test_scalar_functions() uses.
    TEST_FUNCTION_ON_DUAL_DOUBLE(asech, 0.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(acoth, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(acsch, 1.62);

    // other functions
    TEST_FUNCTION_ON_DUAL_DOUBLE(sign, -1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(sign, 0.00);
    TEST_FUNCTION_ON_DUAL_DOUBLE(sign, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(abs, -1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(abs, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(heaviside, -1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(heaviside, 0.00);
    TEST_FUNCTION_ON_DUAL_DOUBLE(heaviside, 1.62);
    TEST_FUNCTION_ON_DUAL_DOUBLE(floor, 1.2);
    TEST_FUNCTION_ON_DUAL_DOUBLE(ceil, 1.2);
    TEST_FUNCTION_ON_DUAL_DOUBLE(round, 1.2);
}
// Checks the plain-double versions of the trigonometric, inverse
// trigonometric, hyperbolic and inverse hyperbolic functions against
// hard-coded reference values, to within TEST_EQ_TOL.
// The brace-enclosed lists in the comments are the reference values for the
// group of assertions that follows them, in the same order.
// The checks go through assert, so they are compiled out under NDEBUG.
void test_scalar_functions()
{
// test basic scalar functions numerically with values from mathematica
cout << "\ntest_scalar_functions()\n";
// trigonometric functions
//{-0.04918382191417056,0.998789743470524,-20.30728204110463,-20.33188884233264,-0.04924341908365026,1.001211723025179}
TEST_EQ_DOUBLE(cos(1.62), -0.04918382191417056);
TEST_EQ_DOUBLE(sin(1.62), 0.998789743470524);
TEST_EQ_DOUBLE(tan(1.62), -20.30728204110463);
TEST_EQ_DOUBLE(sec(1.62), -20.33188884233264);
TEST_EQ_DOUBLE(cot(1.62), -0.04924341908365026);
TEST_EQ_DOUBLE(csc(1.62), 1.001211723025179);
// inverse trigonometric functions
// {0.902053623592525,0.6687427032023717,0.5549957273385867,0.905510600165641,1.01580059945631,0.6652857266292561}
TEST_EQ_DOUBLE(acos(0.62), 0.902053623592525);
TEST_EQ_DOUBLE(asin(0.62), 0.6687427032023717);
TEST_EQ_DOUBLE(atan(0.62), 0.5549957273385867);
TEST_EQ_DOUBLE(asec(1.62), 0.905510600165641);
TEST_EQ_DOUBLE(acot(0.62), 1.01580059945631);
TEST_EQ_DOUBLE(acsc(1.62), 0.6652857266292561);
// hyperbolic trigonometric functions
// {2.625494507823741,2.427595808740127,0.924624218982788,0.3808806291615117,1.081520448491102,0.4119301888723312}
TEST_EQ_DOUBLE(cosh(1.62), 2.625494507823741);
TEST_EQ_DOUBLE(sinh(1.62), 2.427595808740127);
TEST_EQ_DOUBLE(tanh(1.62), 0.924624218982788);
TEST_EQ_DOUBLE(sech(1.62), 0.3808806291615117);
TEST_EQ_DOUBLE(coth(1.62), 1.081520448491102);
TEST_EQ_DOUBLE(csch(1.62), 0.4119301888723312);
// inverse hyperbolic trigonometric functions
// {1.062819127408777,1.259535895278778,0.7250050877529992,1.057231115568124,0.7206050593580027,0.5835891509960214}
// (atanh and asech are evaluated at 0.62, the others at 1.62)
TEST_EQ_DOUBLE(acosh(1.62), 1.062819127408777);
TEST_EQ_DOUBLE(asinh(1.62), 1.259535895278778);
TEST_EQ_DOUBLE(atanh(0.62), 0.7250050877529992);
TEST_EQ_DOUBLE(asech(0.62), 1.057231115568124);
TEST_EQ_DOUBLE(acoth(1.62), 0.7206050593580027);
TEST_EQ_DOUBLE(acsch(1.62), 0.5835891509960214);
}
// Compare the derivative of `fct` at `x` obtained from dual numbers (the `.b`
// member of fct(Dual<double>(x) + Dual<double>::d())) with a central finite
// difference of step DX_NUM_DIFF, and print one report line to std::cout:
//   name(x) : <dual derivative> <finite difference> <difference> <1 if |difference| <= tol, else 0>
// This macro only reports; it never aborts the program on a mismatch.
//
// - No `;` after the parameter list: it would become part of the expansion.
// - Every macro argument is parenthesised so expressions such as `a + b` can
//   be passed for the step or the tolerance.
// - std::abs guarantees a floating-point overload; an unqualified `abs` can
//   resolve to `int abs(int)` and truncate the difference to 0.
// - do { } while (0) makes one invocation exactly one statement.
#define TEST_DERIVATIVE_NUM(fct, x, DX_NUM_DIFF, tol) \
    do { \
        const double ad_derivative_ = fct((Dual<double>((x)) + Dual<double>::d())).b; \
        const double fd_derivative_ = ((fct((x) + (DX_NUM_DIFF))) - (fct((x) - (DX_NUM_DIFF)))) / (2 * (DX_NUM_DIFF)); \
        const bool within_tol_ = std::abs(ad_derivative_ - fd_derivative_) <= (tol); \
        std::cout << std::setw(10) << #fct << "(" << (x) << ") : " << std::setw(25) << ad_derivative_ << " " << std::setw(25) << fd_derivative_ << " " << std::setw(25) << (ad_derivative_ - fd_derivative_) << "\t" << within_tol_ << "\n"; \
    } while (0)
// Prints, for every supported elementary function, the derivative obtained
// from dual numbers next to a central finite-difference estimate, together
// with their difference and a 0/1 "within tolerance" flag.
// Nothing is asserted here: the output is meant to be inspected.
void test_derivative_all()
{
    cout << "\ntest_derivative_all()\n";

    const double step      = 1e-6;   // finite-difference step
    const double tolerance = 1e-9;   // threshold for the printed ok flag
    const double below_one = 0.5;    // default evaluation point
    const double above_one = 1.5;    // used for asec, acsc, acosh, acoth, exp, exp2, sqrt
    const double negative  = -below_one;  // negative side for sign, abs, heaviside

    // trigonometric functions
    TEST_DERIVATIVE_NUM(cos, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(sin, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(tan, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(sec, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(cot, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(csc, below_one, step, tolerance);

    // inverse trigonometric functions
    TEST_DERIVATIVE_NUM(acos, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(asin, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(atan, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(asec, above_one, step, tolerance);
    TEST_DERIVATIVE_NUM(acot, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(acsc, above_one, step, tolerance);

    // hyperbolic functions
    TEST_DERIVATIVE_NUM(cosh, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(sinh, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(tanh, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(sech, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(coth, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(csch, below_one, step, tolerance);

    // inverse hyperbolic functions
    TEST_DERIVATIVE_NUM(acosh, above_one, step, tolerance);
    TEST_DERIVATIVE_NUM(asinh, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(atanh, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(asech, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(acoth, above_one, step, tolerance);
    TEST_DERIVATIVE_NUM(acsch, below_one, step, tolerance);

    // exponentials, logarithms and square root
    TEST_DERIVATIVE_NUM(exp, above_one, step, tolerance);
    TEST_DERIVATIVE_NUM(log, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(exp10, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(log10, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(exp2, above_one, step, tolerance);
    TEST_DERIVATIVE_NUM(log2, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(sqrt, above_one, step, tolerance);

    // piecewise functions, on both sides of zero
    TEST_DERIVATIVE_NUM(sign, negative, step, tolerance);
    TEST_DERIVATIVE_NUM(sign, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(abs, negative, step, tolerance);
    TEST_DERIVATIVE_NUM(abs, below_one, step, tolerance);
    TEST_DERIVATIVE_NUM(heaviside, negative, step, tolerance);
    TEST_DERIVATIVE_NUM(heaviside, below_one, step, tolerance);

    // rounding functions
    TEST_DERIVATIVE_NUM(floor, 1.6, step, tolerance);
    TEST_DERIVATIVE_NUM(ceil, 1.6, step, tolerance);
    TEST_DERIVATIVE_NUM(round, 1.6, step, tolerance);
}
// Checks pow(A, B) on Dual<double> against closed-form results for
// c = a^b with a = 1.5 and b = 5.4:
//   dc/da = b * a^(b-1)      dc/db = a^b * log(a)
// Each partial derivative is read from the `.b` member after adding D::d()
// to the corresponding argument only.
void test_derivative_pow()
{
// test the derivatives of the power function
cout << "\ntest_derivative_pow()\n";
using D = Dual<double>;
double a = 1.5, b = 5.4;
// closed-form reference values
double c = pow(a, b);
double dcda = b*pow(a, b-1);
double dcdb = pow(a, b)*log(a);
D A(a), B(b);
PRINT_VAR(a);
PRINT_VAR(b);
PRINT_VAR(c);
PRINT_VAR(dcda);
PRINT_VAR(dcdb);
PRINT_DUAL(pow(A, B));
PRINT_DUAL(pow(A+D::d(), B));
PRINT_DUAL(pow(A, B+D::d()));
// derivatives obtained from the dual-number evaluation
double dcda_AD = pow(A+D::d(), B).b;
double dcdb_AD = pow(A, B+D::d()).b;
// value and both partial derivatives must match within TEST_EQ_TOL
TEST_EQ_DOUBLE(pow(A, B).a, c);
TEST_EQ_DOUBLE(dcda_AD, dcda);
TEST_EQ_DOUBLE(dcdb_AD, dcdb);
}
// Differentiates the polynomial f1 at x = 3.5 by evaluating it on x + d,
// where d = D(0, 1), and compares the `.b` member of the result with the
// analytic derivative df1(x). Also prints the results of g1, h1 and ddf1.
// Not called from main() in the visible part of this file.
void test_derivative_simple()
{
cout << "\ntest_derivative_simple()\n";
using D = Dual<double>;
D d(0., 1.);
D x = 3.5;
D y = f1(x+d);
D dy = df1(x);
D ddy = ddf1(x);
PRINT_DUAL(x);
PRINT_DUAL(x+d);
PRINT_DUAL(y);
PRINT_DUAL(dy);
// g1 receives the dual x itself, h1 receives its plain value x.a
PRINT_VAR(g1(x));
PRINT_VAR(h1(x.a));
PRINT_VAR(ddy);
// exact floating-point comparison, no tolerance
assert(y.b == dy.a);
}
// Evaluates f2(x) = (x + 2)(x + 1) at x = 3, with and without adding
// d = D(0, 1) to the argument. The expected value is 5 * 4 = 20 and the
// expected `.b` member of the seeded evaluations is 2x + 3 = 9.
// Not called from main() in the visible part of this file.
void test_derivative_simple_2()
{
cout << "\ntest_derivative_simple_2()\n";
using D = Dual<double>;
D d(0., 1.);
D x = 3.;
D x2 = x+d;
D y = f2(x);
D y2 = f2(x2);
D y3 = f2(D(3.0)+d);
PRINT_DUAL(x);
PRINT_DUAL(x+d);
PRINT_DUAL(y);
PRINT_DUAL(y2);
PRINT_DUAL(y3);
// the value part is the same whether or not d was added
assert(y.a == 20.);
assert(y2.a == 20.);
assert(y3.a == 20.);
// the .b member of the seeded evaluations holds the derivative
assert(y2.b == 9.);
assert(y3.b == 9.);
}
// Prints the three partial derivatives of the norm f3 at (2, 4, -7), first
// from the dual-number evaluation (D::d() added to one argument at a time)
// and then from the analytic formulas df3x, df3y and df3z.
// Nothing is asserted; the two groups of values are meant to be compared by eye.
// Not called from main() in the visible part of this file.
void test_derivative_simple_3()
{
// partial derivatives using scalar implementation
cout << "\ntest_derivative_simple_3()\n";
using D = Dual<double>;
D x(2.), y(4.), z(-7.);
D L = f3(x, y, z);
PRINT_DUAL(L);
// partial derivatives from the dual-number evaluation
PRINT_VAR(f3(x+D::d(), y, z).b);
PRINT_VAR(f3(x, y+D::d(), z).b);
PRINT_VAR(f3(x, y, z+D::d()).b);
// analytic partial derivatives evaluated on the plain values
PRINT_VAR(df3x(x.a, y.a, z.a));
PRINT_VAR(df3y(x.a, y.a, z.a));
PRINT_VAR(df3z(x.a, y.a, z.a));
}
// Test function 1 / atan(1 - x^2), generic in Scalar.
// `pow` and `atan` are called unqualified so that overloads for Scalar can be
// picked up. The result is not finite where atan's argument is zero (x = ±1).
template<typename Scalar>
Scalar testFunctionNesting(const Scalar & x)
{
    const auto xSquared = pow(x, Scalar(2.));
    const auto angle = atan(Scalar(1.) - xSquared);
    return Scalar(1.) / angle;
}
// Function object computing 1 / atan(1 - x^2), the same expression as the
// free function testFunctionNesting.
// The call operator is const because the functor holds no state; this lets it
// be invoked on const instances and through const references. Existing calls
// on non-const instances keep working.
template<typename Scalar>
struct TestFunctionNestingFunctor
{
    // Returns 1 / atan(1 - x^2); not finite where atan's argument is zero (x = ±1).
    Scalar operator()(const Scalar & x) const
    {
        return Scalar(1.)/atan(Scalar(1.) - pow(x, Scalar(2.)));
    }
};
/* template<typename Scalar, typename FunctorType>
std::vector<Scalar> computeNfirstDerivatives(FunctorType & functor, const Scalar & x, const int & n, int depth = 0)
{
// compute the n first derivatives using the recursive, nested approach
std::vector<Scalar> derivatives;
if(depth < n)
derivatives.push_back(computeNfirstDerivatives(functor, Dual<Scalar>(x) + Dual<Scalar>::d(), n, ++depth)[0].b);
return derivatives;
}
void test_derivative_nested()
{
// nested function to compute the first nth derivatives using scalar implementation
cout << "\ntest_derivative_nested()\n";
using D = Dual<double>;
double x = 0.7;
TestFunctionNestingFunctor<double> ffun;
std::vector<double> derivatives = computeNfirstDerivatives<double, TestFunctionNestingFunctor<double> >(ffun, x, 1);
} */