// Catch2 unit tests for the Dual-number automatic differentiation header (AutomaticDifferentiation.hpp).
#include <catch2/catch.hpp>
|
|
#include "AutomaticDifferentiation.hpp"
|
|
|
|
// Default absolute tolerance used by EQ_DOUBLE.
#define TEST_EQ_TOL 1e-9

// True when a and b differ by at most TEST_EQ_TOL (absolute difference).
#define EQ_DOUBLE(a, b) (fabs((a) - (b)) <= TEST_EQ_TOL)

// Same as EQ_DOUBLE, but with a caller-supplied absolute tolerance.
#define EQ_DOUBLE_TOL(a, b, tol) (fabs((a) - (b)) <= (tol))

// Checks that the value part (.a) of fct applied to Dual<double>(x)
// matches fct applied to the plain double x, within TEST_EQ_TOL.
#define CHECK_FUNCTION_ON_DUAL_DOUBLE(fct, x) CHECK( EQ_DOUBLE((fct(Dual<double>((x))).a), fct((x))) )
|
|
|
|
// Checks the derivative part (.b) of fct evaluated at Dual<double>(x) + d()
// against the central finite difference (fct(x+dx) - fct(x-dx)) / (2*dx),
// within the absolute tolerance tol.
// DX_NUM_DIFF and tol are parenthesized so expression arguments expand safely.
#define TEST_DERIVATIVE_NUM(fct, x, DX_NUM_DIFF, tol) \
    CHECK( fabs(fct((Dual<double>((x))+Dual<double>::d())).b - (((fct((x)+(DX_NUM_DIFF))) - (fct((x)-(DX_NUM_DIFF))))/(2*(DX_NUM_DIFF)))) <= (tol) )
|
|
|
|
// Test function for nth derivative: f(x) = 1 / atan(1 - x^2).
// Every constant is built as a Scalar so the same body can be instantiated
// with double and with Dual types (see dtestFunction1 / ddtestFunction1).
template<typename Scalar>
Scalar testFunction1(const Scalar & x)
{
    const Scalar one(1.);
    const Scalar two(2.);
    return one/atan(one - pow(x, two));
}
|
|
|
|
template<typename Scalar>
|
|
Scalar dtestFunction1(const Scalar & x)
|
|
{
|
|
return testFunction1(Dual<Scalar>(x) + Dual<Scalar>::d()).b;
|
|
}
|
|
|
|
template<typename Scalar>
|
|
Scalar ddtestFunction1(const Scalar & x)
|
|
{
|
|
return dtestFunction1(Dual<Scalar>(x) + Dual<Scalar>::d()).b;
|
|
}
|
|
|
|
// Closed-form first derivative of testFunction1:
//   f'(x) = 2x / ((1 + u^2) * atan(u)^2),  with u = 1 - x^2.
// The shared subexpressions are computed once; the order of the arithmetic
// operations is the same as in the fully expanded formula.
template<typename Scalar>
Scalar dtestFunction1_sym(const Scalar & x)
{
    const Scalar one(1.);
    const Scalar two(2.);
    const Scalar u = one - pow(x, two);   // u = 1 - x^2
    const Scalar atanU = atan(u);
    return (two*x)/((one + pow(u, two))*pow(atanU, two));
}
|
|
|
|
// Closed-form second derivative of testFunction1. With u = 1 - x^2:
//   f''(x) = 8x^2     / ((1 + u^2)^2 * atan(u)^3)
//          + 8x^2 * u / ((1 + u^2)^2 * atan(u)^2)
//          + 2        / ((1 + u^2)   * atan(u)^2)
// The shared subexpressions are computed once; the three terms are evaluated
// and summed in the same order as in the fully expanded formula.
template<typename Scalar>
Scalar ddtestFunction1_sym(const Scalar & x)
{
    const Scalar one(1.);
    const Scalar two(2.);
    const Scalar three(3.);
    const Scalar eight(8.);

    const Scalar xSq = pow(x, two);                   // x^2
    const Scalar u = one - xSq;                       // 1 - x^2
    const Scalar atanU = atan(u);
    const Scalar onePlusUSq = one + pow(u, two);      // 1 + u^2
    const Scalar onePlusUSqSq = pow(onePlusUSq, two); // (1 + u^2)^2

    const Scalar term1 = (eight*xSq)/(onePlusUSqSq*pow(atanU, three));
    const Scalar term2 = (eight*xSq*u)/(onePlusUSqSq*pow(atanU, two));
    const Scalar term3 = two/(onePlusUSq*pow(atanU, two));
    return term1 + term2 + term3;
}
|
|
|
|
// Length of vector from coordinates: sqrt(z^2 + y^2 + x^2).
template<typename Scalar>
Scalar f3(const Scalar & x, const Scalar & y, const Scalar & z)
{
    const Scalar squaredLength = z*z+y*y+x*x;
    return sqrt(squaredLength);
}
|
|
|
|
// Closed-form partial derivative of f3 with respect to x: x / sqrt(z^2 + y^2 + x^2).
template<typename Scalar>
Scalar df3x(const Scalar & x, const Scalar & y, const Scalar & z)
{
    const Scalar length = sqrt(z*z+y*y+x*x);
    return x/length;
}
|
|
|
|
// Closed-form partial derivative of f3 with respect to y: y / sqrt(z^2 + y^2 + x^2).
template<typename Scalar>
Scalar df3y(const Scalar & x, const Scalar & y, const Scalar & z)
{
    const Scalar length = sqrt(z*z+y*y+x*x);
    return y/length;
}
|
|
|
|
// Closed-form partial derivative of f3 with respect to z: z / sqrt(z^2 + y^2 + x^2).
template<typename Scalar>
Scalar df3z(const Scalar & x, const Scalar & y, const Scalar & z)
{
    const Scalar length = sqrt(z*z+y*y+x*x);
    return z/length;
}
|
|
|
|
// Exercises the constructors, comparison operators and arithmetic operators
// of Dual<double> through its public .a and .b members.
TEST_CASE( "Basic Dual class tests", "[basic]" ) {

    // These variables are initialized afresh for each section
    double x1 = 1.62, x2 = 0.62, x3 = 3.14, x4 = 2.71;
    using D = Dual<double>;

    SECTION( "Constructors" ) {
        // REQUIRE( D().a == 0. );
        // REQUIRE( D().b == 0. );
        REQUIRE( D(x1).a == x1 );
        REQUIRE( D(x1).b == 0. );
        REQUIRE( D(x1, x2).a == x1 );
        REQUIRE( D(x1, x2).b == x2 );

        // copy constructor
        D X(x1, x2);
        D Y(X);
        REQUIRE( X.a == Y.a );
        REQUIRE( X.b == Y.b );

        // d() function
        REQUIRE( D::d().a == 0. );
        REQUIRE( D::d().b == 1. );
    }

    SECTION( "Comparison operators" ) {
        D X(x1, x2);
        D Y(x3, x4);

        // equal
        REQUIRE( (X == X) );
        REQUIRE_FALSE( (X == Y) );

        // different
        REQUIRE_FALSE( (X != X) );
        REQUIRE( (X != Y) );

        // lower than: ordering must agree with the ordering of the .a parts
        REQUIRE( (X < Y) == (X.a < Y.a) );
        REQUIRE( (X < X) == (X.a < X.a) );
        REQUIRE( (X <= Y) == (X.a <= Y.a) );
        REQUIRE( (X <= X) == (X.a <= X.a) );

        // greater than
        REQUIRE( (X > Y) == (X.a > Y.a) );
        REQUIRE( (X > X) == (X.a > X.a) );
        REQUIRE( (X >= Y) == (X.a >= Y.a) );
        REQUIRE( (X >= X) == (X.a >= X.a) );
    }

    SECTION( "Operators for operations" ) {
        D X(x1, x2);
        D Y(x3, x4);

        // .a part follows ordinary arithmetic
        REQUIRE( (X+Y).a == X.a+Y.a );
        REQUIRE( (X-Y).a == X.a-Y.a );
        REQUIRE( (X*Y).a == X.a*Y.a );
        REQUIRE( (X/Y).a == X.a/Y.a );

        // .b part follows the sum, product and quotient rules
        REQUIRE( (X+Y).b == X.b+Y.b );
        REQUIRE( (X-Y).b == X.b-Y.b );
        REQUIRE( (X*Y).b == X.a*Y.b+X.b*Y.a );
        REQUIRE( (X/Y).b == (Y.a*X.b - X.a*Y.b)/(Y.a*Y.a) );

        // increment and decrement operators, mirrored on a plain double
        double aValue = X.a;
        REQUIRE( (++X).a == ++aValue);
        REQUIRE( (--X).a == --aValue);
        REQUIRE( (X++).a == aValue++);
        REQUIRE( (X).a == aValue);// check that it was incremented properly
        REQUIRE( (X--).a == aValue--);
        REQUIRE( (X).a == aValue);// check that it was decremented properly
    }
}
|
|
|
|
// Checks the scalar math functions on plain doubles against hard-coded
// reference values, using the absolute tolerance TEST_EQ_TOL.
TEST_CASE( "Scalar functions tests", "[scalarFunctions]" ) {
    SECTION( "Scalar functions" ) {
        REQUIRE( fabs(-5.) == 5. );
        REQUIRE( fabs(5.) == 5. );

        // trigonometric functions
        CHECK( EQ_DOUBLE(cos(1.62), -0.04918382191417056) );
        CHECK( EQ_DOUBLE(sin(1.62), 0.998789743470524) );
        CHECK( EQ_DOUBLE(tan(1.62), -20.30728204110463) );
        CHECK( EQ_DOUBLE(sec(1.62), -20.33188884233264) );
        CHECK( EQ_DOUBLE(cot(1.62), -0.04924341908365026) );
        CHECK( EQ_DOUBLE(csc(1.62), 1.001211723025179) );

        // inverse trigonometric functions
        CHECK( EQ_DOUBLE(acos(0.62), 0.902053623592525) );
        CHECK( EQ_DOUBLE(asin(0.62), 0.6687427032023717) );
        CHECK( EQ_DOUBLE(atan(0.62), 0.5549957273385867) );
        CHECK( EQ_DOUBLE(asec(1.62), 0.905510600165641) );
        CHECK( EQ_DOUBLE(acot(0.62), 1.01580059945631) );
        CHECK( EQ_DOUBLE(acsc(1.62), 0.6652857266292561) );

        // hyperbolic trigonometric functions
        CHECK( EQ_DOUBLE(cosh(1.62), 2.625494507823741) );
        CHECK( EQ_DOUBLE(sinh(1.62), 2.427595808740127) );
        CHECK( EQ_DOUBLE(tanh(1.62), 0.924624218982788) );
        CHECK( EQ_DOUBLE(sech(1.62), 0.3808806291615117) );
        CHECK( EQ_DOUBLE(coth(1.62), 1.081520448491102) );
        CHECK( EQ_DOUBLE(csch(1.62), 0.4119301888723312) );

        // inverse hyperbolic trigonometric functions
        CHECK( EQ_DOUBLE(acosh(1.62), 1.062819127408777) );
        CHECK( EQ_DOUBLE(asinh(1.62), 1.259535895278778) );
        CHECK( EQ_DOUBLE(atanh(0.62), 0.7250050877529992) );
        CHECK( EQ_DOUBLE(asech(0.62), 1.057231115568124) );
        CHECK( EQ_DOUBLE(acoth(1.62), 0.7206050593580027) );
        CHECK( EQ_DOUBLE(acsch(1.62), 0.5835891509960214) );

        // other functions
        CHECK( EQ_DOUBLE(exp10(1.62), pow(10., 1.62)) );
        CHECK( EQ_DOUBLE(sign(1.62), 1.) );
        CHECK( EQ_DOUBLE(sign(-1.62), -1.) );
        CHECK( EQ_DOUBLE(sign(0.), 0.) );
        CHECK( EQ_DOUBLE(heaviside(-1.62), 0.) );
        CHECK( EQ_DOUBLE(heaviside(0.), 1.) );
        CHECK( EQ_DOUBLE(heaviside(1.62), 1.) );
    }
}
|
|
|
|
// Checks the math functions on Dual<double>:
//  - the .a part must match the same function applied to a plain double;
//  - the .b part must match a central finite-difference derivative.
TEST_CASE( "Functions on dual numbers", "[FunctionsOnDualNumbers]" ) {
    SECTION( "Functions on Dual numbers" ) {
        // trigonometric functions
        CHECK_FUNCTION_ON_DUAL_DOUBLE(cos, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(sin, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(tan, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(sec, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(cot, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(csc, 1.62);

        // inverse trigonometric functions
        CHECK_FUNCTION_ON_DUAL_DOUBLE(acos, 0.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(asin, 0.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(atan, 0.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(asec, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(acot, 0.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(acsc, 1.62);

        // exponential functions
        CHECK_FUNCTION_ON_DUAL_DOUBLE(exp, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(log, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(exp10, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(log10, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(exp2, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(log2, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(sqrt, 1.62);

        // hyperbolic trigonometric functions
        CHECK_FUNCTION_ON_DUAL_DOUBLE(cosh, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(sinh, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(tanh, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(sech, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(coth, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(csch, 1.62);

        // inverse hyperbolic trigonometric functions
        CHECK_FUNCTION_ON_DUAL_DOUBLE(acosh, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(asinh, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(atanh, 0.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(asech, 0.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(acoth, 1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(acsch, 1.62);

        // other functions
        CHECK_FUNCTION_ON_DUAL_DOUBLE(sign, -1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(sign, 0.00);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(sign, 1.62);

        CHECK( abs(Dual<double>(-1.62)) == fabs(-1.62) );
        CHECK( fabs(Dual<double>(-1.62)) == fabs(-1.62) );
        CHECK_FUNCTION_ON_DUAL_DOUBLE(fabs, -1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(fabs, 1.62);

        CHECK_FUNCTION_ON_DUAL_DOUBLE(heaviside, -1.62);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(heaviside, 0.00);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(heaviside, 1.62);

        CHECK_FUNCTION_ON_DUAL_DOUBLE(floor, 1.2);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(ceil, 1.2);
        CHECK_FUNCTION_ON_DUAL_DOUBLE(round, 1.2);
    }

    SECTION( "Function derivatives checked numerically" ) {
        // Evaluation points: x = 0.5, x2 = 1.5 and x3 = -0.5.
        const double x = 0.5;
        const double x2 = 1.5;
        const double x3 = -x;
        const double dx = 1e-6;   // finite-difference step
        const double tol = 1e-9;  // absolute tolerance on the derivative

        // trigonometric functions
        TEST_DERIVATIVE_NUM(cos, x, dx, tol);
        TEST_DERIVATIVE_NUM(sin, x, dx, tol);
        TEST_DERIVATIVE_NUM(tan, x, dx, tol);
        TEST_DERIVATIVE_NUM(sec, x, dx, tol);
        TEST_DERIVATIVE_NUM(cot, x, dx, tol);
        TEST_DERIVATIVE_NUM(csc, x, dx, tol);

        // inverse trigonometric functions
        TEST_DERIVATIVE_NUM(acos, x, dx, tol);
        TEST_DERIVATIVE_NUM(asin, x, dx, tol);
        TEST_DERIVATIVE_NUM(atan, x, dx, tol);
        TEST_DERIVATIVE_NUM(asec, x2, dx, tol);
        TEST_DERIVATIVE_NUM(acot, x, dx, tol);
        TEST_DERIVATIVE_NUM(acsc, x2, dx, tol);

        // hyperbolic trigonometric functions
        TEST_DERIVATIVE_NUM(cosh, x, dx, tol);
        TEST_DERIVATIVE_NUM(sinh, x, dx, tol);
        TEST_DERIVATIVE_NUM(tanh, x, dx, tol);
        TEST_DERIVATIVE_NUM(sech, x, dx, tol);
        TEST_DERIVATIVE_NUM(coth, x, dx, tol);
        TEST_DERIVATIVE_NUM(csch, x, dx, tol);

        // inverse hyperbolic trigonometric functions
        TEST_DERIVATIVE_NUM(acosh, x2, dx, tol);
        TEST_DERIVATIVE_NUM(asinh, x, dx, tol);
        TEST_DERIVATIVE_NUM(atanh, x, dx, tol);
        TEST_DERIVATIVE_NUM(asech, x, dx, tol);
        TEST_DERIVATIVE_NUM(acoth, x2, dx, tol);
        TEST_DERIVATIVE_NUM(acsch, x, dx, tol);

        // exponential functions
        TEST_DERIVATIVE_NUM(exp, x2, dx, tol);
        TEST_DERIVATIVE_NUM(log, x, dx, tol);
        TEST_DERIVATIVE_NUM(exp10, x, dx, tol);
        TEST_DERIVATIVE_NUM(log10, x, dx, tol);
        TEST_DERIVATIVE_NUM(exp2, x2, dx, tol);
        TEST_DERIVATIVE_NUM(log2, x, dx, tol);

        // other functions
        TEST_DERIVATIVE_NUM(sqrt, x2, dx, tol);
        TEST_DERIVATIVE_NUM(sign, x3, dx, tol);
        TEST_DERIVATIVE_NUM(sign, x, dx, tol);
        TEST_DERIVATIVE_NUM(fabs, x3, dx, tol);
        TEST_DERIVATIVE_NUM(fabs, x, dx, tol);
        TEST_DERIVATIVE_NUM(heaviside, x3, dx, tol);
        TEST_DERIVATIVE_NUM(heaviside, x, dx, tol);
        TEST_DERIVATIVE_NUM(floor, 1.6, dx, tol);
        TEST_DERIVATIVE_NUM(ceil, 1.6, dx, tol);
        TEST_DERIVATIVE_NUM(round, 1.6, dx, tol);
    }
}
|
|
|
|
// Checks first and second derivatives of testFunction1 at x = 0.7, comparing
// finite differences, the closed-form expressions and the dual-number results.
TEST_CASE( "Nth derivative", "[NthDerivative]" ) {
    const double x = 0.7;
    const double dx = 1e-6;
    const double fx = testFunction1(x);
    // central finite differences for the first and second derivatives
    const double dfdx = (testFunction1(x+dx) - testFunction1(x-dx))/(2*dx);
    const double d2fdx2 = (testFunction1(x-dx) - 2*fx + testFunction1(x+dx))/(dx*dx);

    // value part and closed-form derivatives against the finite differences
    CHECK( testFunction1(Dual<double>(x)).a == fx );
    CHECK( EQ_DOUBLE_TOL(dtestFunction1_sym(x), dfdx, 1e-6) );
    CHECK( EQ_DOUBLE_TOL(ddtestFunction1_sym(x), d2fdx2, 1e-4) );

    // The dual-number derivative and the closed-form expression are computed
    // through different sequences of floating-point operations, so they are
    // compared within TEST_EQ_TOL rather than with exact equality.
    CHECK( EQ_DOUBLE(testFunction1(Dual<double>(x) + Dual<double>::d()).b, dtestFunction1_sym(x)) );
    CHECK( EQ_DOUBLE(dtestFunction1(x), dtestFunction1_sym(x)) );
    CHECK( EQ_DOUBLE(ddtestFunction1(x), ddtestFunction1_sym(x)) );
}
|
|
|
|
// Checks the partial derivatives of f3: adding Dual<double>::d() to one
// argument at a time, the .b part of the result is compared with the
// matching closed-form partial derivative (df3x, df3y, df3z).
TEST_CASE( "Partial derivatives", "[PartialDerivatives]" ) {
    const double x = 0.7;
    const double y = -2.;
    const double z = 1.5;

    CHECK( f3(Dual<double>(x), Dual<double>(y), Dual<double>(z)) == f3(x,y,z) );

    // The derivative parts and the closed-form expressions may be computed
    // through different floating-point operations, so they are compared
    // within TEST_EQ_TOL rather than with exact equality.
    CHECK( EQ_DOUBLE(f3(Dual<double>(x)+Dual<double>::d(), Dual<double>(y), Dual<double>(z)).b, df3x(x,y,z)) );
    CHECK( EQ_DOUBLE(f3(Dual<double>(x), Dual<double>(y)+Dual<double>::d(), Dual<double>(z)).b, df3y(x,y,z)) );
    CHECK( EQ_DOUBLE(f3(Dual<double>(x), Dual<double>(y), Dual<double>(z)+Dual<double>::d()).b, df3z(x,y,z)) );
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|