// AutomaticDifferentiation/old/test_AutomaticDifferentiation_vector.cpp
// (473 lines, 14 KiB, C++, executable file)

#include "AutomaticDifferentiationVector.hpp"
#include <assert.h>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <iomanip>
using std::cout;
using std::cout;
using std::setw;
using std::abs;
// Print an expression's source text followed by its value (16 significant digits).
#define PRINT_VAR(x) std::cout << #x << "\t= " << std::setprecision(16) << (x) << std::endl
// Print the value part `.a` and derivative part `.b` of a DualVector expression.
// Note: std::fixed and setprecision(4) remain in effect on std::cout afterwards.
#define PRINT_DUALVECTOR(x) std::cout << #x << "\t= " << std::fixed << std::setprecision(4) << std::setw(10) << (x).a << ", " << std::setw(10) << (x).b << std::endl
// Absolute tolerance shared by the floating-point checks in this file.
#define TEST_EQ_TOL 1e-9
// assert() that |a - b| <= TEST_EQ_TOL. A NaN on either side fails the check.
#define TEST_EQ_DOUBLE(a, b) assert(std::abs((a) - (b)) <= TEST_EQ_TOL)
// Check that fct evaluated on DualVector<double,1>(x) has the same value part
// `.a` as fct(x) evaluated on a plain double.
//  - Both results NaN counts as agreement (same behavior outside the domain).
//  - Exactly one NaN, or |difference| > TEST_EQ_TOL, prints a diagnostic to
//    std::cerr and terminates with exit(1). Written as !(diff <= tol) so a NaN
//    difference can no longer slip through the comparison unnoticed.
// Each side is evaluated once. The do { } while (0) wrapper makes the macro a
// single statement, so it is safe inside an unbraced if/else.
#define TEST_FUNCTION_ON_DualVector_DOUBLE(fct, x) \
do { \
    const double ad_value_ = (fct(DualVector<double,1>((x))).a); \
    const double ref_value_ = fct((x)); \
    const bool both_nan_ = std::isnan(ad_value_) && std::isnan(ref_value_); \
    if(!both_nan_ && !(std::abs(ad_value_ - ref_value_) <= TEST_EQ_TOL)) { \
        std::cerr << "Assertion failed at " << __FILE__ << ":" << __LINE__ << " : " << #fct << "<DualVector<double,1>>(" << (x) << ") != " << #fct << "(" << (x) << ")" << "\n";\
        std::cerr << "Got " << ad_value_ << " ; expected " << ref_value_ << "\n";\
        exit(1);\
    } \
} while(0)
// Debug helper: prints this instantiation's compiler-generated signature
// (GCC/Clang __PRETTY_FUNCTION__), which spells out the deduced type T.
template<typename T>
void print_T()
{
    std::cout << __PRETTY_FUNCTION__ << '\n';
}
// Cubic test polynomial f1(x) = 5x^3 + 3x^2 - 2x + 4.
// The terms are summed left to right, exactly as in ((c + q) - l) + 4.
template<typename Scalar>
Scalar f1(const Scalar & x)
{
    Scalar cubic = Scalar(5.) * x * x * x;
    Scalar quadratic = Scalar(3.) * x * x;
    Scalar linear = Scalar(2.) * x;
    return cubic + quadratic - linear + Scalar(4.);
}
// Hand-written first derivative of f1: f1'(x) = 15x^2 + 6x - 2.
template<typename Scalar>
Scalar df1(const Scalar & x)
{
    Scalar quadratic = Scalar(15.) * x * x;
    Scalar linear = Scalar(6.) * x;
    return quadratic + linear - Scalar(2.);
}
// Hand-written second derivative of f1: f1''(x) = 30x + 6.
template<typename Scalar>
Scalar ddf1(const Scalar & x)
{
    Scalar slope_term = Scalar(30.) * x;
    return slope_term + Scalar(6.);
}
template<typename Scalar, typename Vector>
Vector g1(Scalar x) {
return f1(DualVector<Scalar,1>(x) + DualVector<Scalar,1>::d()).b;
}
// Presumably meant to compute the second derivative of f1 by feeding a dual
// number into g1 (nested forward-mode AD); compare ddf1 above.
// NOTE(review): this template cannot be instantiated as written. g1's `Vector`
// template parameter appears only in its return type, so it is not deducible
// from the call below, and no explicit template arguments are supplied. The
// only mention of h1 in this file is inside commented-out code, so nothing
// currently instantiates it.
template<typename Scalar, typename Vector>
Vector h1(Scalar x) {
return g1(DualVector<Scalar,1>(x) + DualVector<Scalar,1>::d()).b;
}
// Quadratic test function f2(x) = (x + 2)(x + 1); derivative 2x + 3.
template<class D>
D f2(D x)
{
    D plus_two = x + D(2.0);
    D plus_one = x + D(1.0);
    return plus_two * plus_one;
}
// Euclidean norm of (x, y, z): sqrt(z^2 + y^2 + x^2).
// sqrt is unqualified so argument-dependent lookup can select an overload for
// non-double Scalar types.
template<typename Scalar>
Scalar f3(const Scalar & x, const Scalar & y, const Scalar & z)
{
    Scalar squared_norm = z*z + y*y + x*x;
    return sqrt(squared_norm);
}
// Partial derivative of f3 with respect to x: x / sqrt(z^2 + y^2 + x^2).
template<typename Scalar>
Scalar df3x(const Scalar & x, const Scalar & y, const Scalar & z)
{
    Scalar norm = sqrt(z*z + y*y + x*x);
    return x / norm;
}
// Partial derivative of f3 with respect to y: y / sqrt(z^2 + y^2 + x^2).
template<typename Scalar>
Scalar df3y(const Scalar & x, const Scalar & y, const Scalar & z)
{
    Scalar norm = sqrt(z*z + y*y + x*x);
    return y / norm;
}
// Partial derivative of f3 with respect to z: z / sqrt(z^2 + y^2 + x^2).
template<typename Scalar>
Scalar df3z(const Scalar & x, const Scalar & y, const Scalar & z)
{
    Scalar norm = sqrt(z*z + y*y + x*x);
    return z / norm;
}
void test_basic();
void test_scalar_functions();
void test_derivative_all();
void test_derivative_pow();
void test_derivative_simple();
void test_derivative_simple_2();
void test_derivative_simple_3();
// Test driver: runs the enabled suites in a fixed order and returns 0 once
// all of them have completed.
int main()
{
    // test_derivative_simple is not part of the run.
    using TestSuite = void (*)();
    const TestSuite suites[] = {
        test_scalar_functions,
        test_basic,
        test_derivative_all,
        test_derivative_pow,
        test_derivative_simple_2,
        test_derivative_simple_3,
    };
    for (const TestSuite run_suite : suites)
        run_suite();
    return 0;
}
// Smoke test of DualVector<double, 3> construction and arithmetic, followed by
// a sweep checking that every elementary function evaluated on a
// DualVector<double,1> agrees, in its value part `.a`, with the same function
// evaluated on a plain double.
void test_basic()
{
    cout << "\ntest_basic()\n";
    using D = DualVector<double, 3>;
    cout.precision(16);
    double x = 2;
    double y = 5;
    D X(x), Y(y);
    // Constructing from a scalar sets the value and leaves every derivative slot at 0.
    assert(X.a == x);
    for(size_t i = 0 ; i < 3 ; i++)
        assert(X.b[i] == 0.);
    assert(Y.a == y);
    for(size_t i = 0 ; i < 3 ; i++)
        assert(Y.b[i] == 0.);
    // Value parts of the arithmetic operators match plain double arithmetic.
    assert((X+Y).a == (x+y));
    assert((X-Y).a == (x-y));
    assert((X*Y).a == (x*y));
    assert((X/Y).a == (x/y));
    // Negate the DualVector itself, then read its value. The old form
    // `-X.a` only negated a double and never exercised DualVector's unary minus.
    assert((-X).a == -x);
    PRINT_DUALVECTOR(X+Y);
    PRINT_DUALVECTOR(X-Y);
    PRINT_DUALVECTOR(X*Y);
    PRINT_DUALVECTOR(X/Y);
    assert(D(1., 2.).a == 1.);
    // assert(D(1., 2.).b == 2.);
    assert((D(1., 2.)+D(4., 7.)).a == 5.);
    // assert((D(1., 2.)+D(4., 7.)).b == 9.);
    // test all the value returned by non-linear functions
    assert(heaviside(-1.) == 0.);
    assert(heaviside(0.) == 1.);
    assert(heaviside(1.) == 1.);
    assert(sign(-10.) == -1.);
    assert(sign(0.) == 0.);
    assert(sign(10.) == 1.);
    PRINT_VAR(abs(-10.));
    PRINT_VAR(abs(10.));
    PRINT_VAR(abs(DualVector<double,1>(-10.)).a);
    PRINT_VAR(abs(DualVector<double,1>(10.)).a);
    PRINT_VAR(abs(DualVector<double,1>(-1.62)).a);
    PRINT_VAR(abs(-1.62));
    PRINT_VAR(exp10(-3.));
    PRINT_VAR(exp10(-2.));
    PRINT_VAR(exp10(-1.));
    PRINT_VAR(exp10(0.));
    PRINT_VAR(exp10(1.));
    PRINT_VAR(exp10(2.));
    PRINT_VAR(exp10(3.));
    x = -1.5;
    PRINT_VAR((x >= (0.)) ? pow((10.), x) : (1.)/pow((10.), -x));
    PRINT_VAR(pow(10., -3.));
    PRINT_VAR(pow(10., 3.));
    PRINT_VAR(1./pow(10., 3.));
    PRINT_VAR(exp2(-3.));
    PRINT_VAR(exp2(3.));
    TEST_EQ_DOUBLE(exp10(-3.), 1./exp10(3.));
    TEST_EQ_DOUBLE(exp10(-3.), 0.001);
    TEST_EQ_DOUBLE(exp10( 3.), 1000.);
    PRINT_VAR(atanh(0.62));
    PRINT_VAR(atanh(DualVector<double,1>(0.62)).a);
    PRINT_VAR(acsc(DualVector<double,1>(1.62)).a);
    PRINT_VAR(acsc(1.62));
    TEST_EQ_DOUBLE(pow(DualVector<double,1>(1.62), DualVector<double,1>(1.5)).a, pow(1.62, 1.5));
    // Every argument below must lie inside the function's real domain.
    // Otherwise both sides are NaN and the comparison checks nothing.
    // trigonometric functions
    TEST_FUNCTION_ON_DualVector_DOUBLE(cos, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(sin, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(tan, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(sec, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(cot, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(csc, 1.62);
    // inverse trigonometric functions
    TEST_FUNCTION_ON_DualVector_DOUBLE(acos, 0.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(asin, 0.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(atan, 0.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(asec, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(acot, 0.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(acsc, 1.62);
    // exponential functions
    TEST_FUNCTION_ON_DualVector_DOUBLE(exp, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(log, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(exp10, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(log10, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(exp2, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(log2, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(sqrt, 1.62);
    // hyperbolic trigonometric functions
    TEST_FUNCTION_ON_DualVector_DOUBLE(cosh, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(sinh, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(tanh, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(sech, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(coth, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(csch, 1.62);
    // inverse hyperbolic trigonometric functions
    TEST_FUNCTION_ON_DualVector_DOUBLE(acosh, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(asinh, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(atanh, 0.62);
    // asech is only real on (0, 1]; 1.62 gave NaN on both sides.
    TEST_FUNCTION_ON_DualVector_DOUBLE(asech, 0.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(acoth, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(acsch, 1.62);
    // other functions
    TEST_FUNCTION_ON_DualVector_DOUBLE(sign, -1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(sign, 0.00);
    TEST_FUNCTION_ON_DualVector_DOUBLE(sign, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(abs, -1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(abs, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(heaviside, -1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(heaviside, 0.00);
    TEST_FUNCTION_ON_DualVector_DOUBLE(heaviside, 1.62);
    TEST_FUNCTION_ON_DualVector_DOUBLE(floor, 1.2);
    TEST_FUNCTION_ON_DualVector_DOUBLE(ceil, 1.2);
    TEST_FUNCTION_ON_DualVector_DOUBLE(round, 1.2);
}
// Check the plain-double versions of the elementary functions against
// reference values. Each checked value is listed, to 16 significant digits,
// in the "{...}" comment above its group; the existing comment attributes them
// to Mathematica.
// Note: sec, cot, csc, asec, acot, acsc, sech, coth, csch, asech, acoth and
// acsch are not part of <cmath>; presumably their double overloads come from
// AutomaticDifferentiationVector.hpp.
void test_scalar_functions()
{
// test basic scalar functions numerically with values from mathematica
cout << "\ntest_scalar_functions()\n";
// trigonometric functions, evaluated at 1.62
//{-0.04918382191417056,0.998789743470524,-20.30728204110463,-20.33188884233264,-0.04924341908365026,1.001211723025179}
TEST_EQ_DOUBLE(cos(1.62), -0.04918382191417056);
TEST_EQ_DOUBLE(sin(1.62), 0.998789743470524);
TEST_EQ_DOUBLE(tan(1.62), -20.30728204110463);
TEST_EQ_DOUBLE(sec(1.62), -20.33188884233264);
TEST_EQ_DOUBLE(cot(1.62), -0.04924341908365026);
TEST_EQ_DOUBLE(csc(1.62), 1.001211723025179);
// inverse trigonometric functions: 0.62 for acos/asin/atan/acot, 1.62 for asec/acsc
// {0.902053623592525,0.6687427032023717,0.5549957273385867,0.905510600165641,1.01580059945631,0.6652857266292561}
TEST_EQ_DOUBLE(acos(0.62), 0.902053623592525);
TEST_EQ_DOUBLE(asin(0.62), 0.6687427032023717);
TEST_EQ_DOUBLE(atan(0.62), 0.5549957273385867);
TEST_EQ_DOUBLE(asec(1.62), 0.905510600165641);
TEST_EQ_DOUBLE(acot(0.62), 1.01580059945631);
TEST_EQ_DOUBLE(acsc(1.62), 0.6652857266292561);
// hyperbolic trigonometric functions, evaluated at 1.62
// {2.625494507823741,2.427595808740127,0.924624218982788,0.3808806291615117,1.081520448491102,0.4119301888723312}
TEST_EQ_DOUBLE(cosh(1.62), 2.625494507823741);
TEST_EQ_DOUBLE(sinh(1.62), 2.427595808740127);
TEST_EQ_DOUBLE(tanh(1.62), 0.924624218982788);
TEST_EQ_DOUBLE(sech(1.62), 0.3808806291615117);
TEST_EQ_DOUBLE(coth(1.62), 1.081520448491102);
TEST_EQ_DOUBLE(csch(1.62), 0.4119301888723312);
// inverse hyperbolic trigonometric functions: 0.62 for atanh/asech (domains
// inside (-1, 1) and (0, 1]), 1.62 for the others
// {1.062819127408777,1.259535895278778,0.7250050877529992,1.057231115568124,0.7206050593580027,0.5835891509960214}
TEST_EQ_DOUBLE(acosh(1.62), 1.062819127408777);
TEST_EQ_DOUBLE(asinh(1.62), 1.259535895278778);
TEST_EQ_DOUBLE(atanh(0.62), 0.7250050877529992);
TEST_EQ_DOUBLE(asech(0.62), 1.057231115568124);
TEST_EQ_DOUBLE(acoth(1.62), 0.7206050593580027);
TEST_EQ_DOUBLE(acsch(1.62), 0.5835891509960214);
}
// Compare the forward-mode derivative of fct at x against a numerical estimate.
//  - AD side: the b[0] slot of fct(DualVector<double,1>(x) + d()).
//  - Numerical side: the 6th-order central finite difference with step
//    DX_NUM_DIFF, using coefficients -1/60, 3/20, -3/4, 0, 3/4, -3/20, 1/60.
// Prints one row (name, x, AD value, numerical value, difference, ok) and
// assert()s that |AD - numerical| <= tol.
// DX_NUM_DIFF is actually used as the step. Previously the body read a
// variable named `dx` from the caller's scope. x and DX_NUM_DIFF are each
// evaluated once; the do { } while (0) wrapper makes this a single statement.
#define TEST_DERIVATIVE_NUM(fct, x, DX_NUM_DIFF, tol) \
do {\
    const double x0_ = (x);\
    const double h_ = (DX_NUM_DIFF);\
    const double dfdx_ = fct((DualVector<double,1>(x0_)+DualVector<double,1>::d())).b[0];\
    double dfdx_num_ = double(-1.)/double(60.) * fct(x0_-3*h_)\
                     + double( 3.)/double(20.) * fct(x0_-2*h_)\
                     + double(-3.)/double(4. ) * fct(x0_-1*h_)\
                     + double( 3.)/double(4. ) * fct(x0_+1*h_)\
                     + double(-3.)/double(20.) * fct(x0_+2*h_)\
                     + double( 1.)/double(60.) * fct(x0_+3*h_);\
    dfdx_num_ /= h_;\
    const bool ok_ = std::abs(dfdx_ - dfdx_num_) <= (tol);\
    cout << setw(10) << #fct << "(" << x0_ << ") : " << setw(25) << dfdx_ << " " << setw(25) << dfdx_num_ << " " << setw(25) << (dfdx_-dfdx_num_) << "\t" << ok_ << "\n";\
    assert(ok_);\
} while(0)
// For every elementary function, compare the derivative computed by
// DualVector forward-mode AD with a 6th-order central finite-difference
// estimate, via TEST_DERIVATIVE_NUM (step dx, absolute tolerance tol).
void test_derivative_all()
{
cout << "\ntest_derivative_all()\n";
// Compare each AD derivative with its numerical finite-difference estimate.
// Sample points:
//   x  = 0.5  : default point.
//   x2 = 1.5  : for asec, acsc, acosh and acoth, whose real domains exclude 0.5;
//               also used for exp, exp2 and sqrt.
//   x3 = -0.5 : negative side of sign, abs and heaviside.
// floor/ceil/round use 1.6, at least 3*dx away from any jump, so every
// stencil point lies on the same flat piece.
double x = 0.5, x2 = 1.5, x3 = -x, dx = 1e-3, tol = 1e-9;
TEST_DERIVATIVE_NUM(cos, x, dx, tol);
TEST_DERIVATIVE_NUM(sin, x, dx, tol);
TEST_DERIVATIVE_NUM(tan, x, dx, tol);
TEST_DERIVATIVE_NUM(sec, x, dx, tol);
TEST_DERIVATIVE_NUM(cot, x, dx, tol);
TEST_DERIVATIVE_NUM(csc, x, dx, tol);
TEST_DERIVATIVE_NUM(acos, x, dx, tol);
TEST_DERIVATIVE_NUM(asin, x, dx, tol);
TEST_DERIVATIVE_NUM(atan, x, dx, tol);
TEST_DERIVATIVE_NUM(asec, x2, dx, tol);
TEST_DERIVATIVE_NUM(acot, x, dx, tol);
TEST_DERIVATIVE_NUM(acsc, x2, dx, tol);
TEST_DERIVATIVE_NUM(cosh, x, dx, tol);
TEST_DERIVATIVE_NUM(sinh, x, dx, tol);
TEST_DERIVATIVE_NUM(tanh, x, dx, tol);
TEST_DERIVATIVE_NUM(sech, x, dx, tol);
TEST_DERIVATIVE_NUM(coth, x, dx, tol);
TEST_DERIVATIVE_NUM(csch, x, dx, tol);
TEST_DERIVATIVE_NUM(acosh, x2, dx, tol);
TEST_DERIVATIVE_NUM(asinh, x, dx, tol);
TEST_DERIVATIVE_NUM(atanh, x, dx, tol);
TEST_DERIVATIVE_NUM(asech, x, dx, tol);
TEST_DERIVATIVE_NUM(acoth, x2, dx, tol);
TEST_DERIVATIVE_NUM(acsch, x, dx, tol);
TEST_DERIVATIVE_NUM(exp, x2, dx, tol);
TEST_DERIVATIVE_NUM(log, x, dx, tol);
TEST_DERIVATIVE_NUM(exp10, x, dx, tol);
TEST_DERIVATIVE_NUM(log10, x, dx, tol);
TEST_DERIVATIVE_NUM(exp2, x2, dx, tol);
TEST_DERIVATIVE_NUM(log2, x, dx, tol);
TEST_DERIVATIVE_NUM(sqrt, x2, dx, tol);
// Piecewise-constant functions: both the AD and numerical derivatives should be 0 here.
TEST_DERIVATIVE_NUM(sign, x3, dx, tol);
TEST_DERIVATIVE_NUM(sign, x, dx, tol);
TEST_DERIVATIVE_NUM(abs, x3, dx, tol);
TEST_DERIVATIVE_NUM(abs, x, dx, tol);
TEST_DERIVATIVE_NUM(heaviside, x3, dx, tol);
TEST_DERIVATIVE_NUM(heaviside, x, dx, tol);
TEST_DERIVATIVE_NUM(floor, 1.6, dx, tol);
TEST_DERIVATIVE_NUM(ceil, 1.6, dx, tol);
TEST_DERIVATIVE_NUM(round, 1.6, dx, tol);
}
void test_derivative_pow()
{
// test the derivatives of the power function
cout << "\ntest_derivative_pow()\n";
using D = DualVector<double,2>;
double a = 1.5, b = 5.4;
double c = pow(a, b);
double dcda = b*pow(a, b-1);
double dcdb = pow(a, b)*log(a);
D A(a), B(b);
PRINT_VAR(a);
PRINT_VAR(b);
PRINT_VAR(c);
PRINT_VAR(dcda);
PRINT_VAR(dcdb);
PRINT_DUALVECTOR(pow(A, B));
PRINT_DUALVECTOR(pow(A+D::d(0), B+D::d(1)));
D A1(A+D::d(0)), B1(B+D::d(1)), C(0.);
PRINT_DUALVECTOR(exp(B1*log(A1)));
C = pow(A1, B1);
TEST_EQ_DOUBLE(C.a , c);
TEST_EQ_DOUBLE(C.b[0], dcda);
TEST_EQ_DOUBLE(C.b[1], dcdb);
}
// Prints its section header only; the checks below are compiled out.
void test_derivative_simple()
{
    cout << "\ntest_derivative_simple()\n";
#if 0
    // Disabled checks of f1 and its hand-written derivatives.
    using D = DualVector<double,1>;
    D d = D::d();
    D x = 3.5;
    D y = f1(x+D::d());
    D dy = df1(x);
    D ddy = ddf1(x);
    PRINT_DUALVECTOR(x);
    PRINT_DUALVECTOR(x+d);
    PRINT_DUALVECTOR(y);
    PRINT_DUALVECTOR(dy);
    PRINT_VAR((g1<double, D::VectorT>(x)));
    // PRINT_VAR(h1(x.a));
    PRINT_VAR(ddy);
    assert(y.b[0] == dy.a);
#endif
}
// f2(x) = (x + 2)(x + 1) at x = 3 has value 20 and derivative 2x + 3 = 9.
// The derivative is obtained by adding the seed d = (0, 1) either through a
// named variable (y2) or an inline temporary (y3).
void test_derivative_simple_2()
{
    cout << "\ntest_derivative_simple_2()\n";
    using D = DualVector<double,1>;
    D d(0., 1.);
    D x(3.);
    PRINT_DUALVECTOR(x);
    PRINT_DUALVECTOR(x+d);

    // Unseeded input: only the value part is checked.
    D y(f2(x));
    PRINT_DUALVECTOR(y);

    // Seeded through a named variable...
    D x2(x + d);
    D y2(f2(x2));
    PRINT_DUALVECTOR(y2);

    // ...and through a temporary built in place.
    D y3(f2(D(3.0)+d));
    PRINT_DUALVECTOR(y3);

    assert(y.a == 20.);
    assert(y2.a == 20.);
    assert(y3.a == 20.);
    assert(y2.b[0] == 9.);
    assert(y3.b[0] == 9.);
}
void test_derivative_simple_3()
{
// partial derivatives using scalar implementation
cout << "\ntest_derivative_simple_3()\n";
using D = DualVector<double,3>;
D x(2.), y(4.), z(-7.);
D L = f3(x, y, z);
PRINT_DUALVECTOR(L);
PRINT_VAR(f3(x+D::d(0), y+D::d(1), z+D::d(2)).b);
PRINT_VAR(df3x(x.a, y.a, z.a));
PRINT_VAR(df3y(x.a, y.a, z.a));
PRINT_VAR(df3z(x.a, y.a, z.a));
}