#include #include "AutomaticDifferentiation.hpp" #define TEST_EQ_TOL 1e-9 #define EQ_DOUBLE(a, b) (fabs((a) - (b)) <= TEST_EQ_TOL) #define EQ_DOUBLE_TOL(a, b, tol) (fabs((a) - (b)) <= (tol)) #define CHECK_FUNCTION_ON_DUAL_DOUBLE(fct, x) CHECK( EQ_DOUBLE((fct(Dual((x))).a), fct((x))) ) #define TEST_DERIVATIVE_NUM(fct, x, DX_NUM_DIFF, tol) \ CHECK( fabs(fct((Dual((x))+Dual::d())).b - (((fct((x)+DX_NUM_DIFF)) - (fct((x)-DX_NUM_DIFF)))/(2*DX_NUM_DIFF))) <= tol ) // Test function for nth derivative template Scalar testFunction1(const Scalar & x) { return Scalar(1.)/atan(Scalar(1.) - pow(x, Scalar(2.))); } template Scalar dtestFunction1(const Scalar & x) { return testFunction1(Dual(x) + Dual::d()).b; } template Scalar ddtestFunction1(const Scalar & x) { return dtestFunction1(Dual(x) + Dual::d()).b; } template Scalar dtestFunction1_sym(const Scalar & x) { return (Scalar(2.)*x)/((Scalar(1.) + pow(Scalar(1.) - pow(x,Scalar(2.)),Scalar(2.)))*pow(atan(Scalar(1.) - pow(x,Scalar(2.))),Scalar(2.))); } template Scalar ddtestFunction1_sym(const Scalar & x) { return (Scalar(8.)*pow(x,Scalar(2.)))/(pow(Scalar(1.) + pow(Scalar(1.) - pow(x,Scalar(2.)),Scalar(2.)),Scalar(2.))* pow(atan(Scalar(1.) - pow(x,Scalar(2.))),Scalar(3.))) + (Scalar(8.)*pow(x,Scalar(2.))*(Scalar(1.) - pow(x,Scalar(2.))))/(pow(Scalar(1.) + pow(Scalar(1.) - pow(x,Scalar(2.)),Scalar(2.)),Scalar(2.))* pow(atan(Scalar(1.) - pow(x,Scalar(2.))),Scalar(2.))) + Scalar(2.)/((Scalar(1.) + pow(Scalar(1.) - pow(x,Scalar(2.)),Scalar(2.)))*pow(atan(Scalar(1.) - pow(x,Scalar(2.))),Scalar(2.))); } // Length of vector from coordinates template Scalar f3(const Scalar & x, const Scalar & y, const Scalar & z) { return sqrt(z*z+y*y+x*x); } template Scalar df3x(const Scalar & x, const Scalar & y, const Scalar & z) { return x/sqrt(z*z+y*y+x*x); } template Scalar df3y(const Scalar & x, const Scalar & y, const Scalar & z) { return y/sqrt(z*z+y*y+x*x); } template Scalar df3z(const Scalar & x, const Scalar & y, const Scalar & z) { return z/sqrt(z*z+y*y+x*x); } TEST_CASE( "Basic Dual class tests", "[basic]" ) { // For each section, these variables are anew double x1 = 1.62, x2 = 0.62, x3 = 3.14, x4 = 2.71; using D = Dual; SECTION( "Constructors" ) { // REQUIRE( D().a == 0. ); // REQUIRE( D().b == 0. ); REQUIRE( D(x1).a == x1 ); REQUIRE( D(x1).b == 0. ); REQUIRE( D(x1, x2).a == x1 ); REQUIRE( D(x1, x2).b == x2 ); // copy constructor D X(x1, x2); D Y(X); REQUIRE( X.a == Y.a ); REQUIRE( X.b == Y.b ); // d() function REQUIRE( D::d().a == 0. ); REQUIRE( D::d().b == 1. ); } SECTION( "Comparison operators" ) { D X(x1, x2); D Y(x3, x4); // equal REQUIRE( (X == X) ); REQUIRE_FALSE( (X == Y) ); // different REQUIRE_FALSE( (X != X) ); REQUIRE( (X != Y) ); // lower than REQUIRE( (X < Y) == (X.a < Y.a) ); REQUIRE( (X < X) == (X.a < X.a) ); REQUIRE( (X <= Y) == (X.a <= Y.a) ); REQUIRE( (X <= X) == (X.a <= X.a) ); // greater than REQUIRE( (X > Y) == (X.a > Y.a) ); REQUIRE( (X > X) == (X.a > X.a) ); REQUIRE( (X >= Y) == (X.a >= Y.a) ); REQUIRE( (X >= X) == (X.a >= X.a) ); } SECTION( "Operators for operations" ) { D X(x1, x2); D Y(x3, x4); REQUIRE( (X+Y).a == X.a+Y.a ); REQUIRE( (X-Y).a == X.a-Y.a ); REQUIRE( (X*Y).a == X.a*Y.a ); REQUIRE( (X/Y).a == X.a/Y.a ); REQUIRE( (X+Y).b == X.b+Y.b ); REQUIRE( (X-Y).b == X.b-Y.b ); REQUIRE( (X*Y).b == X.a*Y.b+X.b*Y.a ); REQUIRE( (X/Y).b == (Y.a*X.b - X.a*Y.b)/(Y.a*Y.a) ); // increment and decrement operators double aValue = X.a; REQUIRE( (++X).a == ++aValue); REQUIRE( (--X).a == --aValue); REQUIRE( (X++).a == aValue++); REQUIRE( (X).a == aValue);// check that it was incremented properly REQUIRE( (X--).a == aValue--); REQUIRE( (X).a == aValue);// check that it was decremented properly } } TEST_CASE( "Scalar functions tests", "[scalarFunctions]" ) { SECTION( "Scalar functions" ) { REQUIRE( fabs(-5.) == 5. ); REQUIRE( fabs(5.) == 5. ); CHECK( EQ_DOUBLE(cos(1.62), -0.04918382191417056) ); CHECK( EQ_DOUBLE(sin(1.62), 0.998789743470524) ); CHECK( EQ_DOUBLE(tan(1.62), -20.30728204110463) ); CHECK( EQ_DOUBLE(sec(1.62), -20.33188884233264) ); CHECK( EQ_DOUBLE(cot(1.62), -0.04924341908365026) ); CHECK( EQ_DOUBLE(csc(1.62), 1.001211723025179) ); // inverse trigonometric functions CHECK( EQ_DOUBLE(acos(0.62), 0.902053623592525) ); CHECK( EQ_DOUBLE(asin(0.62), 0.6687427032023717) ); CHECK( EQ_DOUBLE(atan(0.62), 0.5549957273385867) ); CHECK( EQ_DOUBLE(asec(1.62), 0.905510600165641) ); CHECK( EQ_DOUBLE(acot(0.62), 1.01580059945631) ); CHECK( EQ_DOUBLE(acsc(1.62), 0.6652857266292561) ); // hyperbolic trigonometric functions CHECK( EQ_DOUBLE(cosh(1.62), 2.625494507823741) ); CHECK( EQ_DOUBLE(sinh(1.62), 2.427595808740127) ); CHECK( EQ_DOUBLE(tanh(1.62), 0.924624218982788) ); CHECK( EQ_DOUBLE(sech(1.62), 0.3808806291615117) ); CHECK( EQ_DOUBLE(coth(1.62), 1.081520448491102) ); CHECK( EQ_DOUBLE(csch(1.62), 0.4119301888723312) ); // inverse hyperbolic trigonometric functions CHECK( EQ_DOUBLE(acosh(1.62), 1.062819127408777) ); CHECK( EQ_DOUBLE(asinh(1.62), 1.259535895278778) ); CHECK( EQ_DOUBLE(atanh(0.62), 0.7250050877529992) ); CHECK( EQ_DOUBLE(asech(0.62), 1.057231115568124) ); CHECK( EQ_DOUBLE(acoth(1.62), 0.7206050593580027) ); CHECK( EQ_DOUBLE(acsch(1.62), 0.5835891509960214) ); // other functions CHECK( EQ_DOUBLE(exp10(1.62), pow(10., 1.62)) ); CHECK( EQ_DOUBLE(sign(1.62), 1.) ); CHECK( EQ_DOUBLE(sign(-1.62), -1.) ); CHECK( EQ_DOUBLE(sign(0.), 0.) ); CHECK( EQ_DOUBLE(heaviside(-1.62), 0.) ); CHECK( EQ_DOUBLE(heaviside(0.), 1.) ); CHECK( EQ_DOUBLE(heaviside(1.62), 1.) ); } } TEST_CASE( "Functions on dual numbers", "[FunctionsOnDualNumbers]" ) { SECTION( "Functions on Dual numbers" ) { CHECK_FUNCTION_ON_DUAL_DOUBLE(cos, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(sin, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(tan, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(sec, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(cot, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(csc, 1.62); // inverse trigonometric functions CHECK_FUNCTION_ON_DUAL_DOUBLE(acos, 0.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(asin, 0.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(atan, 0.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(asec, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(acot, 0.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(acsc, 1.62); // exponential functions CHECK_FUNCTION_ON_DUAL_DOUBLE(exp, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(log, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(exp10, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(log10, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(exp2, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(log2, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(sqrt, 1.62); // hyperbolic trigonometric functions CHECK_FUNCTION_ON_DUAL_DOUBLE(cosh, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(sinh, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(tanh, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(sech, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(coth, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(csch, 1.62); // inverse hyperbolic trigonometric functions CHECK_FUNCTION_ON_DUAL_DOUBLE(acosh, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(asinh, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(atanh, 0.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(asech, 0.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(acoth, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(acsch, 1.62); // other functions CHECK_FUNCTION_ON_DUAL_DOUBLE(sign, -1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(sign, 0.00); CHECK_FUNCTION_ON_DUAL_DOUBLE(sign, 1.62); CHECK( abs(Dual(-1.62)) == fabs(-1.62) ); CHECK( fabs(Dual(-1.62)) == fabs(-1.62) ); CHECK_FUNCTION_ON_DUAL_DOUBLE(fabs, -1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(fabs, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(heaviside, -1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(heaviside, 0.00); CHECK_FUNCTION_ON_DUAL_DOUBLE(heaviside, 1.62); CHECK_FUNCTION_ON_DUAL_DOUBLE(floor, 1.2); CHECK_FUNCTION_ON_DUAL_DOUBLE(ceil, 1.2); CHECK_FUNCTION_ON_DUAL_DOUBLE(round, 1.2); } SECTION( "Function derivatives checked numerically" ) { double x = 0.5, x2 = 1.5, x3 = -x, dx = 1e-6, tol = 1e-9; TEST_DERIVATIVE_NUM(cos, x, dx, tol); TEST_DERIVATIVE_NUM(sin, x, dx, tol); TEST_DERIVATIVE_NUM(tan, x, dx, tol); TEST_DERIVATIVE_NUM(sec, x, dx, tol); TEST_DERIVATIVE_NUM(cot, x, dx, tol); TEST_DERIVATIVE_NUM(csc, x, dx, tol); TEST_DERIVATIVE_NUM(acos, x, dx, tol); TEST_DERIVATIVE_NUM(asin, x, dx, tol); TEST_DERIVATIVE_NUM(atan, x, dx, tol); TEST_DERIVATIVE_NUM(asec, x2, dx, tol); TEST_DERIVATIVE_NUM(acot, x, dx, tol); TEST_DERIVATIVE_NUM(acsc, x2, dx, tol); TEST_DERIVATIVE_NUM(cosh, x, dx, tol); TEST_DERIVATIVE_NUM(sinh, x, dx, tol); TEST_DERIVATIVE_NUM(tanh, x, dx, tol); TEST_DERIVATIVE_NUM(sech, x, dx, tol); TEST_DERIVATIVE_NUM(coth, x, dx, tol); TEST_DERIVATIVE_NUM(csch, x, dx, tol); TEST_DERIVATIVE_NUM(acosh, x2, dx, tol); TEST_DERIVATIVE_NUM(asinh, x, dx, tol); TEST_DERIVATIVE_NUM(atanh, x, dx, tol); TEST_DERIVATIVE_NUM(asech, x, dx, tol); TEST_DERIVATIVE_NUM(acoth, x2, dx, tol); TEST_DERIVATIVE_NUM(acsch, x, dx, tol); TEST_DERIVATIVE_NUM(exp, x2, dx, tol); TEST_DERIVATIVE_NUM(log, x, dx, tol); TEST_DERIVATIVE_NUM(exp10, x, dx, tol); TEST_DERIVATIVE_NUM(log10, x, dx, tol); TEST_DERIVATIVE_NUM(exp2, x2, dx, tol); TEST_DERIVATIVE_NUM(log2, x, dx, tol); TEST_DERIVATIVE_NUM(sqrt, x2, dx, tol); TEST_DERIVATIVE_NUM(sign, x3, dx, tol); TEST_DERIVATIVE_NUM(sign, x, dx, tol); TEST_DERIVATIVE_NUM(fabs, x3, dx, tol); TEST_DERIVATIVE_NUM(fabs, x, dx, tol); TEST_DERIVATIVE_NUM(heaviside, x3, dx, tol); TEST_DERIVATIVE_NUM(heaviside, x, dx, tol); TEST_DERIVATIVE_NUM(floor, 1.6, dx, tol); TEST_DERIVATIVE_NUM(ceil, 1.6, dx, tol); TEST_DERIVATIVE_NUM(round, 1.6, dx, tol); } } TEST_CASE( "Nth derivative", "[NthDerivative]" ) { double x = 0.7; double dx = 1e-6; double fx = testFunction1(x); double dfdx = (testFunction1(x+dx) - testFunction1(x-dx))/(2*dx); double d2fdx2 = (testFunction1(x-dx) - 2*fx + testFunction1(x+dx))/(dx*dx); CHECK( testFunction1(Dual(x)).a == fx ); CHECK( EQ_DOUBLE_TOL(dtestFunction1_sym(x), dfdx, 1e-6) ); CHECK( EQ_DOUBLE_TOL(ddtestFunction1_sym(x), d2fdx2, 1e-4) ); CHECK( testFunction1(Dual(x) + Dual::d()).b == dtestFunction1_sym(x) ); CHECK( dtestFunction1(x) == dtestFunction1_sym(x) ); CHECK( EQ_DOUBLE(ddtestFunction1(x), ddtestFunction1_sym(x)) ); } TEST_CASE( "Partial derivatives", "[PartialDerivatives]" ) { double x = 0.7, y = -2., z = 1.5; CHECK( f3(Dual(x), Dual(y), Dual(z)) == f3(x,y,z) ); CHECK( f3(Dual(x)+Dual::d(), Dual(y), Dual(z)).b == df3x(x,y,z) ); CHECK( f3(Dual(x), Dual(y)+Dual::d(), Dual(z)).b == df3y(x,y,z) ); CHECK( f3(Dual(x), Dual(y), Dual(z)+Dual::d()).b == df3z(x,y,z) ); }