Auto test of TOMS748. Working version of 4.1. Auto-Bracketing done. Bracket and solve routine done.

This commit is contained in:
Jérôme 2019-04-04 10:36:27 +02:00
parent d83a997804
commit c3f5f42333
5 changed files with 379 additions and 35 deletions

2
.gitignore vendored
View file

@ -1,5 +1,7 @@
# executable
run
cpp/test/run
cpp/run
*.pdf

View file

@ -3,7 +3,7 @@ C = clang
COMMON_FLAGS = -Wall -MMD
C_FLAGS = $(COMMON_FLAGS)
CC = clang++
CC_FLAGS = $(COMMON_FLAGS) -std=c++17
CC_FLAGS = $(COMMON_FLAGS) -std=c++17 -O3
LD_FLAGS =
INCLUDES =

View file

@ -7,9 +7,12 @@ namespace TOMS748
{
namespace internal
{
template<typename T> T sureAbs(const T & x) { return (x < T(0)) ? -x : x; }
// checks that none of the values are the same
// returns true if at least two values are identical
template<typename T> bool checkTwoValuesIdentical(T fa, T fb, T fc, T fd)
template<typename T>
bool checkTwoValuesIdentical(T fa, T fb, T fc, T fd)
{
bool same = false;
same |= fa == fb;
@ -32,7 +35,8 @@ namespace TOMS748
// standard bracketing routine
// returns ahat, bhat, d
// d is a point outside the new interval
template<typename T> std::tuple<T,T,T,T,T,T> bracket(T a, T b, T c, T fa, T fb, T fc)
template<typename T> std::tuple<T,T,T,T,T,T>
bracket(T a, T b, T c, T fa, T fb, T fc)
{
if(fa*fc < T(0))
return std::make_tuple(a, c, b, fa, fc, fb);
@ -44,10 +48,11 @@ namespace TOMS748
// checks if f(c) is close enough to 0 and returns ok = true if it is the case
// returns ahat, bhat, d
// d is a point outside the new interval
template<typename T> std::tuple<T,T,T,T,T,T,bool> bracketAndCheckConvergence(T a, T b, T c, T fa, T fb, T fc, T tol)
template<typename T> std::tuple<T,T,T,T,T,T,bool>
bracketAndCheckConvergence(T a, T b, T c, T fa, T fb, T fc, T tol)
{
bool ok = false;
if(sureAbs(fc) < tol)
if(internal::sureAbs(fc) < tol)
ok = true;
if(fa*fc < T(0))
return std::make_tuple(a, c, b, fa, fc, fb, ok);
@ -57,7 +62,8 @@ namespace TOMS748
// finds an approximate solution to the quadratic P(x) = fa + f[a,b]*(x-a) + f[a,b,d]*(x-a)(x-b)
// with f[a,b] = fbracket1(a,b) and f[a,b,d] = fbracket2(a,b,d)
template<typename T, int k> T NewtonQuadratic(T a, T b, T d, T fa, T fb, T fd)
template<typename T, int k>
T NewtonQuadratic(T a, T b, T d, T fa, T fb, T fd)
{
T r;
T A = fbracket2(a,b,d,fa,fb,fd);
@ -75,7 +81,8 @@ namespace TOMS748
}
// Inverse cubic interpolation evaluated at 0 (modified Aitken-Neville interpolation)
template<typename T> T ipzero(T a, T b, T c, T d, T fa, T fb, T fc, T fd)
template<typename T>
T ipzero(T a, T b, T c, T d, T fa, T fb, T fc, T fd)
{
T Q11 = (c-d)*fc/(fd-fc);
T Q21 = (b-c)*fb/(fc-fb);
@ -88,37 +95,75 @@ namespace TOMS748
T Q33 = (D32-Q22)*fa/(fd-fa);
return a + Q31 + Q32 + Q33;
}
/// Sorts the pairs (abs(a), fa) and (abs(b), fb) in ascending order.
template<typename T>
void swap_abs_ab_fafb_if_not_ascending(T & a, T & b, T & fa, T & fb)
{
if(internal::sureAbs(b) < internal::sureAbs(a))
{
// swap a and b
T temp = a;
a = b;
b = temp;
// swap fa and fb
temp = fa;
fa = fb;
fb = temp;
}
}
/// Sorts the pairs (a, fa) and (b, fb) in ascending order.
template<typename T>
void swap_ab_fafb_if_not_ascending(T & a, T & b, T & fa, T & fb)
{
if(b < a)
{
// swap a and b
T temp = a;
a = b;
b = temp;
// swap fa and fb
temp = fa;
fa = fb;
fb = temp;
}
}
}// namespace internal
/// Algorithm 4.1 from TOMS748 of robust root-solving.
/// Use this version if f(a) and f(b) have already been computed.
template<typename Func, typename T> std::tuple<T,T,bool> TOMS748_solve1(Func f, T a, T b, T fa, T fb, T tol, int Nmax = 1000)
/// Returns x, f(x), and a boolean indicating if the function converged or not (true if converged).
template<typename Func, typename T>
std::tuple<T,T,bool> TOMS748_solve1(Func f, T a, T b, T fa, T fb, T tol, unsigned int Nmax = 1000)
{
using namespace internal;
T c, d, e, u, dbar, dhat, fc, fd, fe, fu, fdbar, fdhat;
T mu = 0.5;
bool ok;
c = a - fa/fbracket1(a, b, fa, fb); // 4.1.1 secant method
c = a - fa/fbracket1(a, b, fa, fb); // 4.1.1 secant method
fc = f(c);
std::tie(a, b, d, fa, fb, fd, ok) = bracketAndCheckConvergence(a, b, c, fa, fb, fc, tol); // 4.1.2
if(ok) { return std::make_tuple(c, fc, true); }
e = d;
fe = fd;
// ---
for(int n = 2 ; n < Nmax ; n++) // 4.1.3
for(unsigned int n = 2 ; n < Nmax ; n++) // 4.1.3
{
if(n == 2 || checkTwoValuesIdentical(fa, fb, fd, fe))
c = NewtonQuadratic<T,2>(a, b, d, fa, fb, fd);
else
{
c = ipzero(a, b, d, e, fa, fb, fd, fe);
if((c-a)*(c-b) >= T(0))
c = NewtonQuadratic<T,2>(a, b, d, fa, fb, fd);
}
// ---
fc = f(c);
std::tie(a, b, dbar, fa, fb, fdbar, ok) = bracketAndCheckConvergence(a, b, c, fa, fb, fc, tol); // 4.1.4
if(ok) { return std::make_tuple(c, fc, true); }
// ---
if(fabs(fa) < fabs(fb)) // 4.1.5
if(internal::sureAbs(fa) < internal::sureAbs(fb)) // 4.1.5
{
u = a;
fu = fa;
@ -129,16 +174,16 @@ namespace TOMS748
fu = fb;
}
// ---
c = u - 2*fu/fbracket1(a, b, fa, fb); // 4.1.6
c = u - 2*fu/fbracket1(a, b, fa, fb); // 4.1.6
// ---
if(fabs(c - u) > 0.5*(b - a)) // 4.1.7
if(internal::sureAbs(c - u) > 0.5*(b - a)) // 4.1.7
c = 0.5*(b + a);
// ---
fc = f(c);
std::tie(a, b, dhat, fa, fb, fdhat, ok) = bracketAndCheckConvergence(a, b, c, fa, fb, fc, tol); // 4.1.8
if(ok) { return std::make_tuple(c, fc, true); }
// ---
if(b - a < mu*(b - a)) // 4.1.9
if(b - a < mu*(b - a)) // 4.1.9
{
d = dhat;
e = dbar;
@ -158,10 +203,91 @@ namespace TOMS748
return std::make_tuple(c, fc, false);// no solution found, return last estimate
}
template<typename Func, typename T> std::tuple<T,T,bool> TOMS748_solve1(Func f, T a, T b, T tol, int Nmax = 1000)
/// Algorithm 4.1 from TOMS748 of robust root-solving.
/// Use this version if f(a) and f(b) have NOT already been computed.
/// Returns x, f(x), and a boolean indicating if the function converged or not (true if converged).
template<typename Func, typename T>
std::tuple<T,T,bool> TOMS748_solve1(Func f, T a, T b, T tol, unsigned int Nmax = 1000)
{
return TOMS748_solve1(f, a, b, f(a), f(b), tol, Nmax);
}
/// Finds a bracket [a b] that encloses a root of f : f(a)*f(b) < 0, starting from the initial guess x.
/// A typical value for the factor is 2.
/// The initial guess must be on the correct side of the real line : the algorithm will never cross the zero during the search.
/// The factor is increased every 10 iterations in order to speed up convergence for when the root is orders of magnitude away from the initial guess.
/// If the bracket returned by the function is too wide, try reducing the factor (while always keeping it > 1).
///
/// The function returns a tuple of the following values :
/// - T a : lower bound of the interval.
/// - T b : upper bound of the interval.
/// - T fa : function value at lower bound of the interval.
/// - T fb : function value at upper bound of the interval.
/// - bool ok : indicates whether the search was successful or not.
template<typename Func, typename T>
std::tuple<T,T,T,T,bool> findEnclosingBracket(Func f, T x, T factor, unsigned int Nmax = 100)
{
T a, b, fa, fb;
unsigned int i;
a = x;
b = x*factor; // try searching in the positive direction
fa = f(a);
fb = f(b);
internal::swap_abs_ab_fafb_if_not_ascending(a, b, fa, fb);
if(fa*fb <= T(0)) // if the original bracket is already enclosing the solution
{
internal::swap_ab_fafb_if_not_ascending(a, b, fa, fb);
return std::make_tuple(a, b, fa, fb, true);
}
if(internal::sureAbs(fa) < internal::sureAbs(fb)) // if fa is closer, search in the other direction
{
a = x/factor;
b = x;
fb = fa; // f(b) = f(x)
fa = f(a);
internal::swap_abs_ab_fafb_if_not_ascending(a, b, fa, fb);
}
i = 0;
while(i < Nmax && fa*fb > T(0))
{
if(internal::sureAbs(fa) < internal::sureAbs(fb)) // fa is closer to 0 than fb, extend the interval in the a direction (reducing it)
{
a = a / factor;
fa = f(a);
}
else // fb is closer to 0 than fa, extend the interval in the b direction (augmenting it)
{
b = b * factor;
fb = f(b);
}
i = i+1;
if(i % 10 == 0) // every 10 iterations, bump up the factor to speed up convergence
factor = factor * 5;
}
internal::swap_ab_fafb_if_not_ascending(a, b, fa, fb);
return std::make_tuple(a, b, fa, fb, static_cast<bool>(fa*fb <= T(0)));
}
/// Finds a bracket that encloses the solution and then calls TOMS748_solve1 to solve for the root of the function in the bracket.
/// Returns x, f(x), and a boolean indicating if the function converged or not (true if converged).
template<typename Func, typename T>
std::tuple<T,T,bool> bracket_and_solve(Func f, T x, T tol, T factor = T(2), unsigned int Nmax = 1000)
{
T a, b, fa, fb;
bool ok;
std::tie(a, b, fa, fb, ok) = findEnclosingBracket(f, x, factor, Nmax);
if(ok)
return TOMS748_solve1(f, a, b, fa, fb, tol, Nmax);
else
{
internal::swap_abs_ab_fafb_if_not_ascending(fa, fb, a, b);
return std::make_tuple(a, fa, false);
}
}
}// namespace TOMS748
#endif

View file

@ -2,33 +2,249 @@
#include <iostream>
#include "../utils.hpp"
#include "../TOMS748.hpp"
#include "utils_test.hpp"
#define print(x) PRINT_VAR(x);
#define printvec(x) PRINT_VEC(x);
#define printstr(x) PRINT_STR(x);
TEST_CASE( "Test case 1", "[test1]" )
template<typename T> T f(T E) { return E - 0.5*sin(E) - 0.3; }
template<typename T> T f2(T x) { return exp(x)-T(2); }
template<typename T> T f3(T x) { return (x*x)-T(1); }
using namespace TOMS748;
using namespace TOMS748::internal;
TEST_CASE( "sureAbs", "[TOMS748]" )
{
REQUIRE(TOMS748::internal::sureAbs(-1) == 1);
REQUIRE(TOMS748::internal::sureAbs(-1.) == 1.);
}
TEST_CASE( "checkTwoValuesIdentical", "[TOMS748]" )
{
std::cout.precision(16);
SECTION( "Check almost equal" ) {
CHECK(check_almost_equal(1.00, 1.01, 0.1));
CHECK(check_almost_equal(1.00, 3.01, 0.01));
CHECK(check_almost_equal(1.00, 1.01, 0.001));
REQUIRE(true);
}
SECTION( "Check almost equal on vectors" ) {
unsigned int N = 5;
std::vector<double> v1(N), v2(N);
for (size_t i = 0; i < N; i++) {
v1[i] = double(i);
v2[i] = v1[i] + 0.0001;
}
CHECK(check_almost_equalV(v1, v2, 0.001));
CHECK(check_almost_equalV(v1, v2, 0.0001));
CHECK(check_almost_equalV(v1, v2, 0.00001));
SECTION( "checkTwoValuesIdentical" ) {
REQUIRE(checkTwoValuesIdentical(1,2,3,4) == false);
REQUIRE(checkTwoValuesIdentical(1,2,3,1) == true);
REQUIRE(checkTwoValuesIdentical(1,1,3,4) == true);
REQUIRE(checkTwoValuesIdentical(1,2,2,4) == true);
REQUIRE(checkTwoValuesIdentical(1,2,3,3) == true);
REQUIRE(checkTwoValuesIdentical(1,2,1,4) == true);
REQUIRE(checkTwoValuesIdentical(1,2,3,2) == true);
}
}
TEST_CASE( "swap_ab_fafb_if_not_ascending", "[TOMS748]" )
{
std::cout.precision(16);
SECTION( "ascending (do nothing)" ) {
double a = 1., b = 2.,
fa = f(a), fb = f(b);
swap_ab_fafb_if_not_ascending(a,b,fa,fb);
REQUIRE(a == 1.);
REQUIRE(b == 2.);
REQUIRE(fa == f(1.));
REQUIRE(fb == f(2.));
}
SECTION( "descending (swap)" ) {
double a = 2., b = 1.,
fa = f(a), fb = f(b);
swap_ab_fafb_if_not_ascending(a,b,fa,fb);
REQUIRE(a == 1.);
REQUIRE(b == 2.);
REQUIRE(fa == f(1.));
REQUIRE(fb == f(2.));
}
}
TEST_CASE( "Bracket", "[TOMS748]" )
{
std::cout.precision(16);
double a = 0, b = 1, c = 0.3, d = 1.2;
double fa = f(a), fb = f(b), fc = f(c), fd = f(d);
double tol = 1e-12;
bool ok;
SECTION( "checkTwoValuesIdentical" ) {
REQUIRE(check_almost_equal(fbracket1(a, b, fa, fb), 0.5792645075960517, tol));
REQUIRE(check_almost_equal(fbracket2(a, b, d, fa, fb, fd), 0.1619293662546865, tol));
}
SECTION( "bracket normal operation [c b]" ) {
std::tie(a, b, d, fa, fb, fd) = bracket(a, b, c, fa, fb, fc);
REQUIRE(check_almost_equal(a, 0.3, tol));
REQUIRE(check_almost_equal(b, 1., tol));
REQUIRE(check_almost_equal(d, 0., tol));
REQUIRE(check_almost_equal(fa, f(0.3), tol));
REQUIRE(check_almost_equal(fb, f(1.), tol));
REQUIRE(check_almost_equal(fd, f(0.), tol));
}
SECTION( "bracket normal operation [a c]" ) {
std::tie(a, b, d, fa, fb, fd) = bracket(a, b, 0.7, fa, fb, f(0.7));
REQUIRE(check_almost_equal(a, 0., tol));
REQUIRE(check_almost_equal(b, 0.7, tol));
REQUIRE(check_almost_equal(d, 1., tol));
REQUIRE(check_almost_equal(fa, f(0.), tol));
REQUIRE(check_almost_equal(fb, f(0.7), tol));
REQUIRE(check_almost_equal(fd, f(1.), tol));
}
SECTION( "bracket and check normal operation [c b]" ) {
std::tie(a, b, d, fa, fb, fd, ok) = bracketAndCheckConvergence(a, b, c, fa, fb, fc, tol);
REQUIRE(check_almost_equal(a, 0.3, tol));
REQUIRE(check_almost_equal(b, 1., tol));
REQUIRE(check_almost_equal(d, 0., tol));
REQUIRE(check_almost_equal(fa, f(0.3), tol));
REQUIRE(check_almost_equal(fb, f(1.), tol));
REQUIRE(check_almost_equal(fd, f(0.), tol));
REQUIRE_FALSE(ok);
}
SECTION( "bracket and check normal operation [a c]" ) {
std::tie(a, b, d, fa, fb, fd, ok) = bracketAndCheckConvergence(a, b, 0.7, fa, fb, f(0.7), tol);
REQUIRE(check_almost_equal(a, 0., tol));
REQUIRE(check_almost_equal(b, 0.7, tol));
REQUIRE(check_almost_equal(d, 1., tol));
REQUIRE(check_almost_equal(fa, f(0.), tol));
REQUIRE(check_almost_equal(fb, f(0.7), tol));
REQUIRE(check_almost_equal(fd, f(1.), tol));
REQUIRE_FALSE(ok);
}
SECTION( "bracket and check normal operation c is root" ) {
std::tie(a, b, d, fa, fb, fd, ok) = bracketAndCheckConvergence(a, b, 0.569682256443945, fa, fb, f(0.569682256443945), tol);
REQUIRE(check_almost_equal(a, 0., tol));
REQUIRE(check_almost_equal(b, 0.569682256443945, tol));
REQUIRE(check_almost_equal(d, 1., tol));
REQUIRE(check_almost_equal(fa, f(0.), tol));
REQUIRE(check_almost_equal(fb, f(0.569682256443945), tol));
REQUIRE(check_almost_equal(fd, f(1.), tol));
REQUIRE(ok);
}
}
TEST_CASE( "TOMS748_solve1", "[TOMS748]")
{
double a = 0, b = 1, tol = 1e-12;
SECTION( "TOMS748_solve1 good" ) {
auto [x, fx, ok] = TOMS748_solve1(f<double>, a, b, tol);
CHECK(check_almost_equal(fx, 0., tol));
CHECK(ok == true);
}
SECTION( "TOMS748_solve1 wrong initial bracket" ) {
auto [x, fx, ok] = TOMS748_solve1(f<double>, 1., 2., tol);// still converges !
CHECK(check_almost_equal(fx, 0., tol));
CHECK(ok == true);
}
}
TEST_CASE( "findEnclosingBracket", "[TOMS748]")
{
double tol = 1e-12, factor = 2;
SECTION( "findEnclosingBracket start at 0" ) {
double x = 0;
auto [a, b, fa, fb, ok] = findEnclosingBracket(f2<double>, x, factor);
CHECK(a == x);
CHECK(a == b);
CHECK_FALSE(fa*fb <= 0.);
CHECK_FALSE(ok);
}
SECTION( "findEnclosingBracket sol below" ) {
double x = 2.;
auto [a, b, fa, fb, ok] = findEnclosingBracket(f2<double>, x, factor);
CHECK(check_almost_equal(b, x, tol));
CHECK(a < b);
CHECK(fa*fb <= 0.);
CHECK(ok);
}
SECTION( "findEnclosingBracket sol above" ) {
double x = .1;
auto [a, b, fa, fb, ok] = findEnclosingBracket(f2<double>, x, factor);
CHECK(check_almost_equal(a, x, tol));
CHECK(a < b);
CHECK(fa*fb <= 0.);
CHECK(ok);
}
SECTION( "findEnclosingBracket initial bracket good" ) {
double x = .5;
auto [a, b, fa, fb, ok] = findEnclosingBracket(f2<double>, x, factor);
CHECK(check_almost_equal(a, x, tol));
CHECK(a < b);
CHECK(fa*fb <= 0.);
CHECK(ok);
}
SECTION( "findEnclosingBracket sol above far far away" ) {
double x = 1e-10;
auto [a, b, fa, fb, ok] = findEnclosingBracket(f2<double>, x, factor);
CHECK(check_almost_equal(a, x, tol));
CHECK(a < b);
CHECK(fa*fb <= 0.);
CHECK(ok);
}
SECTION( "findEnclosingBracket sol on the other side of 0 (does not converge)" ) {
double x = -1.;
auto [a, b, fa, fb, ok] = findEnclosingBracket(f2<double>, x, factor);
CHECK(a < b);
CHECK_FALSE(fa*fb <= 0.);
CHECK_FALSE(ok);
}
SECTION( "findEnclosingBracket sol below (negative)" ) {
double x = -.1;
auto [a, b, fa, fb, ok] = findEnclosingBracket(f3<double>, x, factor);
CHECK(check_almost_equal(b, x, tol));
CHECK(a < b);
CHECK(fa*fb <= 0.);
CHECK(ok);
}
SECTION( "findEnclosingBracket sol above (negative)" ) {
double x = -2.;
auto [a, b, fa, fb, ok] = findEnclosingBracket(f3<double>, x, factor);
CHECK(check_almost_equal(a, x, tol));
CHECK(a < b);
CHECK(fa*fb <= 0.);
CHECK(ok);
}
}
TEST_CASE( "bracket_and_solve", "[TOMS748]" )
{
double tol = 1e-12;
SECTION( "initial guess close to solution" ) {
auto [x, fx, ok] = bracket_and_solve(f<double>, 0.5, tol);
CHECK(check_almost_equal(fx, 0., tol));
CHECK(ok == true);
}
SECTION( "initial guess far from solution" ) {
auto [x, fx, ok] = bracket_and_solve(f<double>, 100., tol);
CHECK(check_almost_equal(fx, 0., tol));
CHECK(ok == true);
}
SECTION( "initial guess on other side of 0 compared to solution" ) {
auto [x, fx, ok] = bracket_and_solve(f<double>, -1., tol);
CHECK_FALSE(check_almost_equal(fx, 0., tol));
CHECK_FALSE(ok);
}
}

View file

@ -3,7 +3,7 @@ C = gcc
COMMON_FLAGS = -Wall -MMD -fprofile-arcs -ftest-coverage
C_FLAGS = $(COMMON_FLAGS)
CC = g++
CC_FLAGS = $(COMMON_FLAGS) -std=c++17 -O0
CC_FLAGS = $(COMMON_FLAGS) -std=c++17 -O0 -g
LD_FLAGS = -lgcov
INCLUDES =