// Test driver for the DualVector automatic-differentiation type declared in
// "AutomaticDifferentiationVector.hpp".
//
// NOTE(review): this copy of the file looks damaged by extraction - the
// original line breaks are gone and everything that was written between
// angle brackets is missing (the `#include` directives name no header,
// `template` has no parameter list, `DualVector` has no template arguments).
// Restore those from version control before trying to build; the notes below
// describe only what the remaining text shows.
//
// Macros
//   PRINT_VAR(x)         prints the stringified expression and its value with
//                        std::setprecision(16).
//   PRINT_DUALVECTOR(x)  prints (x).a and (x).b.transpose() using std::fixed,
//                        std::setprecision(4) and std::setw(10).
//   TEST_EQ_TOL          absolute tolerance, 1e-9.
//   TEST_EQ_DOUBLE(a, b) assert(abs(a - b) <= TEST_EQ_TOL).
//   TEST_FUNCTION_ON_DualVector_DOUBLE(fct, x)
//                        compares the `.a` member of fct(DualVector(x)) with
//                        fct(x); when they differ by more than TEST_EQ_TOL it
//                        reports on std::cerr and calls exit(1).
//   TEST_DERIVATIVE_NUM(fct, x, DX_NUM_DIFF, tol)
//                        takes `.b[0]` of fct(DualVector(x) + DualVector::d())
//                        and the central difference
//                        (fct(x+DX) - fct(x-DX)) / (2*DX), then prints both,
//                        their difference and an `ok` flag. It only prints:
//                        `ok` is never asserted, so a mismatch does not fail
//                        the run.
//                        NOTE(review): the `;` right after the parameter list
//                        is part of the replacement text - probably unintended.
//
// Helper function templates
//   f1 / df1 / ddf1      5x^3 + 3x^2 - 2x + 4, then 15x^2 + 6x - 2, then
//                        30x + 6.
//   g1 / h1              return the `.b` member of f1 (resp. g1) evaluated at
//                        DualVector(x) + DualVector::d().
//   f2                   (x + 2) * (x + 1).
//   f3                   sqrt(z*z + y*y + x*x); df3x / df3y / df3z divide x, y
//                        or z by that same square root.
//
// main() runs, in this order: test_scalar_functions, test_basic,
// test_derivative_all, test_derivative_pow, test_derivative_simple_2,
// test_derivative_simple_3. The call to test_derivative_simple is commented
// out, and that function's body is itself inside a block comment.
//
// Per-test notes
//   test_basic             checks the `.a` member of +, -, *, / results against
//                          plain double arithmetic, checks heaviside and sign
//                          at negative / zero / positive arguments (the code
//                          asserts heaviside(0.) == 1.), prints several
//                          abs / exp10 / pow / exp2 values, then runs
//                          TEST_FUNCTION_ON_DualVector_DOUBLE over the
//                          function list.
//                          NOTE(review): `assert(-X.a == -x)` negates the
//                          double X.a; it does not apply unary minus to X.
//                          `(-X).a` may have been intended - confirm.
//                          NOTE(review): asech is evaluated at 1.62 here, but
//                          at 0.62 in test_scalar_functions. If asech(1.62)
//                          yields NaN, the `> TEST_EQ_TOL` comparison is false
//                          and the check passes without testing anything -
//                          confirm.
//   test_scalar_functions  compares double-valued functions with hard-coded
//                          reference values (the in-code comment says they come
//                          from Mathematica).
//   test_derivative_all    uses x = 0.5, x2 = 1.5, x3 = -x, dx = 1e-6 and
//                          tol = 1e-9 with TEST_DERIVATIVE_NUM (report only).
//   test_derivative_pow    with a = 1.5 and b = 5.4, checks pow(A + d(0),
//                          B + d(1)): value against pow(a, b), b[0] against
//                          b*pow(a, b-1), b[1] against pow(a, b)*log(a).
//   test_derivative_simple_2
//                          evaluates f2 at 3 and asserts, with exact ==, that
//                          the value is 20 and the derivative entry is 9.
//   test_derivative_simple_3
//                          prints f3 and its partial derivatives at
//                          (2, 4, -7); it contains no assertions.
//
// NOTE(review): `using std::cout;` appears twice; the second one may have been
// meant for another name - confirm.
// NOTE(review): `abs` is called unqualified on doubles; check that the overload
// selected is a floating-point one and not the integer abs.
#include "AutomaticDifferentiationVector.hpp" #include #include #include using std::cout; using std::cout; using std::setw; #define PRINT_VAR(x) std::cout << #x << "\t= " << std::setprecision(16) << (x) << std::endl #define PRINT_DUALVECTOR(x) std::cout << #x << "\t= " << std::fixed << std::setprecision(4) << std::setw(10) << (x).a << ", " << std::setw(10) << (x).b.transpose() << std::endl #define TEST_EQ_TOL 1e-9 // #define TEST_FUNCTION_ON_DualVector_DOUBLE(fct, x) assert(fct(DualVector(x)).a == fct(x)) // #define TEST_FUNCTION_ON_DualVector_DOUBLE(fct, x) assert(abs((fct(DualVector((x))).a) - (fct((x)))) <= TEST_EQ_TOL) #define TEST_EQ_DOUBLE(a, b) assert(abs((a) - (b)) <= TEST_EQ_TOL) #define TEST_FUNCTION_ON_DualVector_DOUBLE(fct, x) \ if(abs((fct(DualVector((x))).a) - fct((x))) > TEST_EQ_TOL) { \ std::cerr << "Assertion failed at " << __FILE__ << ":" << __LINE__ << " : " << #fct << ">(" << (x) << ") != " << #fct << "(" << (x) << ")" << "\n";\ std::cerr << "Got " << (fct(DualVector(x)).a) << " ; expected " << (fct((x))) << "\n";\ exit(1);\ } template void print_T() { std::cout << __PRETTY_FUNCTION__ << '\n'; } template Scalar f1(const Scalar & x) { return Scalar(5.)*x*x*x + Scalar(3.)*x*x - Scalar(2.)*x + Scalar(4.); } template Scalar df1(const Scalar & x) { return Scalar(15.)*x*x + Scalar(6.)*x - Scalar(2.); } template Scalar ddf1(const Scalar & x) { return Scalar(30.)*x + Scalar(6.); } template Vector g1(Scalar x) { return f1(DualVector(x) + DualVector::d()).b; } template Vector h1(Scalar x) { return g1(DualVector(x) + DualVector::d()).b; } template D f2(D x) { return (x + D(2.0)) * (x + D(1.0)); } template Scalar f3(const Scalar & x, const Scalar & y, const Scalar & z) { return sqrt(z*z+y*y+x*x); } template Scalar df3x(const Scalar & x, const Scalar & y, const Scalar & z) { return x/sqrt(z*z+y*y+x*x); } template Scalar df3y(const Scalar & x, const Scalar & y, const Scalar & z) { return y/sqrt(z*z+y*y+x*x); } template Scalar df3z(const Scalar & x, const 
Scalar & y, const Scalar & z) { return z/sqrt(z*z+y*y+x*x); } void test_basic(); void test_scalar_functions(); void test_derivative_all(); void test_derivative_pow(); void test_derivative_simple(); void test_derivative_simple_2(); void test_derivative_simple_3(); int main() { test_scalar_functions(); test_basic(); test_derivative_all(); test_derivative_pow(); // test_derivative_simple(); test_derivative_simple_2(); test_derivative_simple_3(); return 0; } void test_basic() { cout << "\ntest_basic()\n"; using D = DualVector; cout.precision(16); double x = 2; double y = 5; D X(x), Y(y); assert(X.a == x); assert(X.b == D::VectorT::Zero()); assert(Y.a == y); assert(Y.b == D::VectorT::Zero()); assert((X+Y).a == (x+y)); assert((X-Y).a == (x-y)); assert((X*Y).a == (x*y)); assert((X/Y).a == (x/y)); assert(-X.a == -x); PRINT_DUALVECTOR(X+Y); PRINT_DUALVECTOR(X-Y); PRINT_DUALVECTOR(X*Y); PRINT_DUALVECTOR(X/Y); assert(D(1., 2.).a == 1.); // assert(D(1., 2.).b == 2.); assert((D(1., 2.)+D(4., 7.)).a == 5.); // assert((D(1., 2.)+D(4., 7.)).b == 9.); // test all the value returned by non-linear functions assert(heaviside(-1.) == 0.); assert(heaviside(0.) == 1.); assert(heaviside(1.) == 1.); assert(sign(-10.) == -1.); assert(sign(0.) == 0.); assert(sign(10.) == 1.); PRINT_VAR(abs(-10.)); PRINT_VAR(abs(10.)); PRINT_VAR(abs(DualVector(-10.)).a); PRINT_VAR(abs(DualVector(10.)).a); PRINT_VAR(abs(DualVector(-1.62)).a); PRINT_VAR(abs(-1.62)); PRINT_VAR(exp10(-3.)); PRINT_VAR(exp10(-2.)); PRINT_VAR(exp10(-1.)); PRINT_VAR(exp10(0.)); PRINT_VAR(exp10(1.)); PRINT_VAR(exp10(2.)); PRINT_VAR(exp10(3.)); x = -1.5; PRINT_VAR((x >= (0.)) ? 
pow((10.), x) : (1.)/pow((10.), -x)); PRINT_VAR(pow(10., -3.)); PRINT_VAR(pow(10., 3.)); PRINT_VAR(1./pow(10., 3.)); PRINT_VAR(exp2(-3.)); PRINT_VAR(exp2(3.)); TEST_EQ_DOUBLE(exp10(-3.), 1./exp10(3.)); TEST_EQ_DOUBLE(exp10(-3.), 0.001); TEST_EQ_DOUBLE(exp10( 3.), 1000.); PRINT_VAR(atanh(0.62)); PRINT_VAR(atanh(DualVector(0.62)).a); PRINT_VAR(acsc(DualVector(1.62)).a); PRINT_VAR(acsc(1.62)); TEST_EQ_DOUBLE(pow(DualVector(1.62), DualVector(1.5)).a, pow(1.62, 1.5)); // trigonometric functions TEST_FUNCTION_ON_DualVector_DOUBLE(cos, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(sin, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(tan, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(sec, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(cot, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(csc, 1.62); // inverse trigonometric functions TEST_FUNCTION_ON_DualVector_DOUBLE(acos, 0.62); TEST_FUNCTION_ON_DualVector_DOUBLE(asin, 0.62); TEST_FUNCTION_ON_DualVector_DOUBLE(atan, 0.62); TEST_FUNCTION_ON_DualVector_DOUBLE(asec, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(acot, 0.62); TEST_FUNCTION_ON_DualVector_DOUBLE(acsc, 1.62); // exponential functions TEST_FUNCTION_ON_DualVector_DOUBLE(exp, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(log, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(exp10, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(log10, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(exp2, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(log2, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(sqrt, 1.62); // hyperbolic trigonometric functions TEST_FUNCTION_ON_DualVector_DOUBLE(cosh, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(sinh, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(tanh, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(sech, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(coth, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(csch, 1.62); // inverse hyperbolic trigonometric functions TEST_FUNCTION_ON_DualVector_DOUBLE(acosh, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(asinh, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(atanh, 0.62); 
TEST_FUNCTION_ON_DualVector_DOUBLE(asech, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(acoth, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(acsch, 1.62); // other functions TEST_FUNCTION_ON_DualVector_DOUBLE(sign, -1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(sign, 0.00); TEST_FUNCTION_ON_DualVector_DOUBLE(sign, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(abs, -1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(abs, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(heaviside, -1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(heaviside, 0.00); TEST_FUNCTION_ON_DualVector_DOUBLE(heaviside, 1.62); TEST_FUNCTION_ON_DualVector_DOUBLE(floor, 1.2); TEST_FUNCTION_ON_DualVector_DOUBLE(ceil, 1.2); TEST_FUNCTION_ON_DualVector_DOUBLE(round, 1.2); } void test_scalar_functions() { // test basic scalar functions numerically with values from mathematica cout << "\ntest_scalar_functions()\n"; // trigonometric functions //{-0.04918382191417056,0.998789743470524,-20.30728204110463,-20.33188884233264,-0.04924341908365026,1.001211723025179} TEST_EQ_DOUBLE(cos(1.62), -0.04918382191417056); TEST_EQ_DOUBLE(sin(1.62), 0.998789743470524); TEST_EQ_DOUBLE(tan(1.62), -20.30728204110463); TEST_EQ_DOUBLE(sec(1.62), -20.33188884233264); TEST_EQ_DOUBLE(cot(1.62), -0.04924341908365026); TEST_EQ_DOUBLE(csc(1.62), 1.001211723025179); // inverse trigonometric functions // {0.902053623592525,0.6687427032023717,0.5549957273385867,0.905510600165641,1.01580059945631,0.6652857266292561} TEST_EQ_DOUBLE(acos(0.62), 0.902053623592525); TEST_EQ_DOUBLE(asin(0.62), 0.6687427032023717); TEST_EQ_DOUBLE(atan(0.62), 0.5549957273385867); TEST_EQ_DOUBLE(asec(1.62), 0.905510600165641); TEST_EQ_DOUBLE(acot(0.62), 1.01580059945631); TEST_EQ_DOUBLE(acsc(1.62), 0.6652857266292561); // hyperbolic trigonometric functions // {2.625494507823741,2.427595808740127,0.924624218982788,0.3808806291615117,1.081520448491102,0.4119301888723312} TEST_EQ_DOUBLE(cosh(1.62), 2.625494507823741); TEST_EQ_DOUBLE(sinh(1.62), 2.427595808740127); TEST_EQ_DOUBLE(tanh(1.62), 
0.924624218982788); TEST_EQ_DOUBLE(sech(1.62), 0.3808806291615117); TEST_EQ_DOUBLE(coth(1.62), 1.081520448491102); TEST_EQ_DOUBLE(csch(1.62), 0.4119301888723312); // inverse hyperbolic trigonometric functions // {1.062819127408777,1.259535895278778,0.7250050877529992,1.057231115568124,0.7206050593580027,0.5835891509960214} TEST_EQ_DOUBLE(acosh(1.62), 1.062819127408777); TEST_EQ_DOUBLE(asinh(1.62), 1.259535895278778); TEST_EQ_DOUBLE(atanh(0.62), 0.7250050877529992); TEST_EQ_DOUBLE(asech(0.62), 1.057231115568124); TEST_EQ_DOUBLE(acoth(1.62), 0.7206050593580027); TEST_EQ_DOUBLE(acsch(1.62), 0.5835891509960214); } #define TEST_DERIVATIVE_NUM(fct, x, DX_NUM_DIFF, tol); \ {\ double dfdx = fct((DualVector((x))+DualVector::d())).b[0];\ double dfdx_num = ((fct((x)+DX_NUM_DIFF)) - (fct((x)-DX_NUM_DIFF)))/(2*DX_NUM_DIFF);\ bool ok = (abs(dfdx - dfdx_num) <= tol) ? true : false;\ cout << setw(10) << #fct << "(" << (x) << ") : " << setw(25) << dfdx << " " << setw(25) << dfdx_num << " " << setw(25) << (dfdx-dfdx_num) << "\t" << ok << "\n";\ } void test_derivative_all() { cout << "\ntest_derivative_all()\n"; // test all derivatives numerically and check that they add up double x = 0.5, x2 = 1.5, x3 = -x, dx = 1e-6, tol = 1e-9; TEST_DERIVATIVE_NUM(cos, x, dx, tol); TEST_DERIVATIVE_NUM(sin, x, dx, tol); TEST_DERIVATIVE_NUM(tan, x, dx, tol); TEST_DERIVATIVE_NUM(sec, x, dx, tol); TEST_DERIVATIVE_NUM(cot, x, dx, tol); TEST_DERIVATIVE_NUM(csc, x, dx, tol); TEST_DERIVATIVE_NUM(acos, x, dx, tol); TEST_DERIVATIVE_NUM(asin, x, dx, tol); TEST_DERIVATIVE_NUM(atan, x, dx, tol); TEST_DERIVATIVE_NUM(asec, x2, dx, tol); TEST_DERIVATIVE_NUM(acot, x, dx, tol); TEST_DERIVATIVE_NUM(acsc, x2, dx, tol); TEST_DERIVATIVE_NUM(cosh, x, dx, tol); TEST_DERIVATIVE_NUM(sinh, x, dx, tol); TEST_DERIVATIVE_NUM(tanh, x, dx, tol); TEST_DERIVATIVE_NUM(sech, x, dx, tol); TEST_DERIVATIVE_NUM(coth, x, dx, tol); TEST_DERIVATIVE_NUM(csch, x, dx, tol); TEST_DERIVATIVE_NUM(acosh, x2, dx, tol); TEST_DERIVATIVE_NUM(asinh, 
x, dx, tol); TEST_DERIVATIVE_NUM(atanh, x, dx, tol); TEST_DERIVATIVE_NUM(asech, x, dx, tol); TEST_DERIVATIVE_NUM(acoth, x2, dx, tol); TEST_DERIVATIVE_NUM(acsch, x, dx, tol); TEST_DERIVATIVE_NUM(exp, x2, dx, tol); TEST_DERIVATIVE_NUM(log, x, dx, tol); TEST_DERIVATIVE_NUM(exp10, x, dx, tol); TEST_DERIVATIVE_NUM(log10, x, dx, tol); TEST_DERIVATIVE_NUM(exp2, x2, dx, tol); TEST_DERIVATIVE_NUM(log2, x, dx, tol); TEST_DERIVATIVE_NUM(sqrt, x2, dx, tol); TEST_DERIVATIVE_NUM(sign, x3, dx, tol); TEST_DERIVATIVE_NUM(sign, x, dx, tol); TEST_DERIVATIVE_NUM(abs, x3, dx, tol); TEST_DERIVATIVE_NUM(abs, x, dx, tol); TEST_DERIVATIVE_NUM(heaviside, x3, dx, tol); TEST_DERIVATIVE_NUM(heaviside, x, dx, tol); TEST_DERIVATIVE_NUM(floor, 1.6, dx, tol); TEST_DERIVATIVE_NUM(ceil, 1.6, dx, tol); TEST_DERIVATIVE_NUM(round, 1.6, dx, tol); } void test_derivative_pow() { // test the derivatives of the power function cout << "\ntest_derivative_pow()\n"; using D = DualVector; double a = 1.5, b = 5.4; double c = pow(a, b); double dcda = b*pow(a, b-1); double dcdb = pow(a, b)*log(a); D A(a), B(b); PRINT_VAR(a); PRINT_VAR(b); PRINT_VAR(c); PRINT_VAR(dcda); PRINT_VAR(dcdb); PRINT_DUALVECTOR(pow(A, B)); PRINT_DUALVECTOR(pow(A+D::d(0), B+D::d(1))); D A1(A+D::d(0)), B1(B+D::d(1)), C(0.); PRINT_DUALVECTOR(exp(B1*log(A1))); C = pow(A1, B1); TEST_EQ_DOUBLE(C.a , c); TEST_EQ_DOUBLE(C.b[0], dcda); TEST_EQ_DOUBLE(C.b[1], dcdb); } void test_derivative_simple() { cout << "\ntest_derivative_simple()\n"; /*using D = DualVector; D d = D::d(); D x = 3.5; D y = f1(x+D::d()); D dy = df1(x); D ddy = ddf1(x); PRINT_DUALVECTOR(x); PRINT_DUALVECTOR(x+d); PRINT_DUALVECTOR(y); PRINT_DUALVECTOR(dy); PRINT_VAR((g1(x))); // PRINT_VAR(h1(x.a)); PRINT_VAR(ddy); assert(y.b[0] == dy.a);//*/ } void test_derivative_simple_2() { cout << "\ntest_derivative_simple_2()\n"; using D = DualVector; D d(0., 1.); D x = 3.; D x2 = x+d; D y = f2(x); D y2 = f2(x2); D y3 = f2(D(3.0)+d); PRINT_DUALVECTOR(x); PRINT_DUALVECTOR(x+d); 
PRINT_DUALVECTOR(y); PRINT_DUALVECTOR(y2); PRINT_DUALVECTOR(y3); assert(y.a == 20.); assert(y2.a == 20.); assert(y3.a == 20.); assert(y2.b[0] == 9.); assert(y3.b[0] == 9.); } void test_derivative_simple_3() { // partial derivatives using scalar implementation cout << "\ntest_derivative_simple_3()\n"; using D = DualVector; D x(2.), y(4.), z(-7.); D L = f3(x, y, z); PRINT_DUALVECTOR(L); PRINT_VAR(f3(x+D::d(0), y+D::d(1), z+D::d(2)).b.transpose()); PRINT_VAR(df3x(x.a, y.a, z.a)); PRINT_VAR(df3y(x.a, y.a, z.a)); PRINT_VAR(df3z(x.a, y.a, z.a)); }