From c3f5f423338f2c3d34b56288a366140f5e23fa5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me?= Date: Thu, 4 Apr 2019 10:36:27 +0200 Subject: [PATCH] Auto test of TOMS748. Working version of 4.1. Auto-Bracketing done. Bracket and solve routine done. --- .gitignore | 2 + cpp/Makefile | 2 +- cpp/TOMS748.hpp | 154 ++++++++++++++++++++++++--- cpp/test/1_test.cpp | 254 ++++++++++++++++++++++++++++++++++++++++---- cpp/test/Makefile | 2 +- 5 files changed, 379 insertions(+), 35 deletions(-) diff --git a/.gitignore b/.gitignore index 2860229..4f3fd03 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # executable run +cpp/test/run +cpp/run *.pdf diff --git a/cpp/Makefile b/cpp/Makefile index 911d554..3dc38f0 100644 --- a/cpp/Makefile +++ b/cpp/Makefile @@ -3,7 +3,7 @@ C = clang COMMON_FLAGS = -Wall -MMD C_FLAGS = $(COMMON_FLAGS) CC = clang++ -CC_FLAGS = $(COMMON_FLAGS) -std=c++17 +CC_FLAGS = $(COMMON_FLAGS) -std=c++17 -O3 LD_FLAGS = INCLUDES = diff --git a/cpp/TOMS748.hpp b/cpp/TOMS748.hpp index b9832ec..9e5fdce 100644 --- a/cpp/TOMS748.hpp +++ b/cpp/TOMS748.hpp @@ -7,9 +7,12 @@ namespace TOMS748 { namespace internal { + template T sureAbs(const T & x) { return (x < T(0)) ? -x : x; } + // checks that none of the values are the same // returns true if at least two values are identical - template bool checkTwoValuesIdentical(T fa, T fb, T fc, T fd) + template + bool checkTwoValuesIdentical(T fa, T fb, T fc, T fd) { bool same = false; same |= fa == fb; @@ -32,7 +35,8 @@ namespace TOMS748 // standard bracketing routine // returns ahat, bhat, d // d is a point outside the new interval - template std::tuple bracket(T a, T b, T c, T fa, T fb, T fc) + template std::tuple + bracket(T a, T b, T c, T fa, T fb, T fc) { if(fa*fc < T(0)) return std::make_tuple(a, c, b, fa, fc, fb); @@ -44,10 +48,11 @@ namespace TOMS748 // checks if f(c) is close enough to 0 and returns ok = true if it is the case // returns ahat, bhat, d // d is a point outside the new interval - template std::tuple bracketAndCheckConvergence(T a, T b, T c, T fa, T fb, T fc, T tol) + template std::tuple + bracketAndCheckConvergence(T a, T b, T c, T fa, T fb, T fc, T tol) { bool ok = false; - if(sureAbs(fc) < tol) + if(internal::sureAbs(fc) < tol) ok = true; if(fa*fc < T(0)) return std::make_tuple(a, c, b, fa, fc, fb, ok); @@ -57,7 +62,8 @@ namespace TOMS748 // finds an approximate solution to the quadratic P(x) = fa + f[a,b]*(x-a) + f[a,b,d]*(x-a)(x-b) // with f[a,b] = fbracket1(a,b) and f[a,b,d] = fbracket2(a,b,d) - template T NewtonQuadratic(T a, T b, T d, T fa, T fb, T fd) + template + T NewtonQuadratic(T a, T b, T d, T fa, T fb, T fd) { T r; T A = fbracket2(a,b,d,fa,fb,fd); @@ -75,7 +81,8 @@ namespace TOMS748 } // Inverse cubic interpolation evaluated at 0 (modified Aitken-Neville interpolation) - template T ipzero(T a, T b, T c, T d, T fa, T fb, T fc, T fd) + template + T ipzero(T a, T b, T c, T d, T fa, T fb, T fc, T fd) { T Q11 = (c-d)*fc/(fd-fc); T Q21 = (b-c)*fb/(fc-fb); @@ -88,37 +95,75 @@ namespace TOMS748 T Q33 = (D32-Q22)*fa/(fd-fa); return a + Q31 + Q32 + Q33; } + + /// Sorts the pairs (abs(a), fa) and (abs(b), fb) in ascending order. + template + void swap_abs_ab_fafb_if_not_ascending(T & a, T & b, T & fa, T & fb) + { + if(internal::sureAbs(b) < internal::sureAbs(a)) + { + // swap a and b + T temp = a; + a = b; + b = temp; + // swap fa and fb + temp = fa; + fa = fb; + fb = temp; + } + } + + /// Sorts the pairs (a, fa) and (b, fb) in ascending order. + template + void swap_ab_fafb_if_not_ascending(T & a, T & b, T & fa, T & fb) + { + if(b < a) + { + // swap a and b + T temp = a; + a = b; + b = temp; + // swap fa and fb + temp = fa; + fa = fb; + fb = temp; + } + } }// namespace internal /// Algorithm 4.1 from TOMS748 of robust root-solving. /// Use this version if f(a) and f(b) have already been computed. - template std::tuple TOMS748_solve1(Func f, T a, T b, T fa, T fb, T tol, int Nmax = 1000) + /// Returns x, f(x), and a boolean indicating if the function converged or not (true if converged). + template + std::tuple TOMS748_solve1(Func f, T a, T b, T fa, T fb, T tol, unsigned int Nmax = 1000) { using namespace internal; T c, d, e, u, dbar, dhat, fc, fd, fe, fu, fdbar, fdhat; T mu = 0.5; bool ok; - c = a - fa/fbracket1(a, b, fa, fb); // 4.1.1 secant method + c = a - fa/fbracket1(a, b, fa, fb); // 4.1.1 secant method fc = f(c); std::tie(a, b, d, fa, fb, fd, ok) = bracketAndCheckConvergence(a, b, c, fa, fb, fc, tol); // 4.1.2 if(ok) { return std::make_tuple(c, fc, true); } e = d; fe = fd; // --- - for(int n = 2 ; n < Nmax ; n++) // 4.1.3 + for(unsigned int n = 2 ; n < Nmax ; n++) // 4.1.3 { if(n == 2 || checkTwoValuesIdentical(fa, fb, fd, fe)) c = NewtonQuadratic(a, b, d, fa, fb, fd); else + { c = ipzero(a, b, d, e, fa, fb, fd, fe); if((c-a)*(c-b) >= T(0)) c = NewtonQuadratic(a, b, d, fa, fb, fd); + } // --- fc = f(c); std::tie(a, b, dbar, fa, fb, fdbar, ok) = bracketAndCheckConvergence(a, b, c, fa, fb, fc, tol); // 4.1.4 if(ok) { return std::make_tuple(c, fc, true); } // --- - if(fabs(fa) < fabs(fb)) // 4.1.5 + if(internal::sureAbs(fa) < internal::sureAbs(fb)) // 4.1.5 { u = a; fu = fa; @@ -129,16 +174,16 @@ namespace TOMS748 fu = fb; } // --- - c = u - 2*fu/fbracket1(a, b, fa, fb); // 4.1.6 + c = u - 2*fu/fbracket1(a, b, fa, fb); // 4.1.6 // --- - if(fabs(c - u) > 0.5*(b - a)) // 4.1.7 + if(internal::sureAbs(c - u) > 0.5*(b - a)) // 4.1.7 c = 0.5*(b + a); // --- fc = f(c); std::tie(a, b, dhat, fa, fb, fdhat, ok) = bracketAndCheckConvergence(a, b, c, fa, fb, fc, tol); // 4.1.8 if(ok) { return std::make_tuple(c, fc, true); } // --- - if(b - a < mu*(b - a)) // 4.1.9 + if(b - a < mu*(b - a)) // 4.1.9 { d = dhat; e = dbar; @@ -158,10 +203,91 @@ namespace TOMS748 return std::make_tuple(c, fc, false);// no solution found, return last estimate } - template std::tuple TOMS748_solve1(Func f, T a, T b, T tol, int Nmax = 1000) + /// Algorithm 4.1 from TOMS748 of robust root-solving. + /// Use this version if f(a) and f(b) have NOT already been computed. + /// Returns x, f(x), and a boolean indicating if the function converged or not (true if converged). + template + std::tuple TOMS748_solve1(Func f, T a, T b, T tol, unsigned int Nmax = 1000) { return TOMS748_solve1(f, a, b, f(a), f(b), tol, Nmax); } + + /// Finds a bracket [a b] that encloses a root of f : f(a)*f(b) < 0, starting from the initial guess x. + /// A typical value for the factor is 2. + /// The initial guess must be on the correct side of the real line : the algorithm will never cross the zero during the search. + /// The factor is increased every 10 iterations in order to speed up convergence for when the root is orders of magnitude away from the initial guess. + /// If the bracket returned by the function is too wide, try reducing the factor (while always keeping it > 1). + /// + /// The function returns a tuple of the following values : + /// - T a : lower bound of the interval. + /// - T b : upper bound of the interval. + /// - T fa : function value at lower bound of the interval. + /// - T fb : function value at upper bound of the interval. + /// - bool ok : indicates whether the search was successful or not. + template + std::tuple findEnclosingBracket(Func f, T x, T factor, unsigned int Nmax = 100) + { + T a, b, fa, fb; + unsigned int i; + a = x; + b = x*factor; // try searching in the positive direction + fa = f(a); + fb = f(b); + internal::swap_abs_ab_fafb_if_not_ascending(a, b, fa, fb); + + if(fa*fb <= T(0)) // if the original bracket is already enclosing the solution + { + internal::swap_ab_fafb_if_not_ascending(a, b, fa, fb); + return std::make_tuple(a, b, fa, fb, true); + } + + if(internal::sureAbs(fa) < internal::sureAbs(fb)) // if fa is closer, search in the other direction + { + a = x/factor; + b = x; + fb = fa; // f(b) = f(x) + fa = f(a); + internal::swap_abs_ab_fafb_if_not_ascending(a, b, fa, fb); + } + i = 0; + while(i < Nmax && fa*fb > T(0)) + { + if(internal::sureAbs(fa) < internal::sureAbs(fb)) // fa is closer to 0 than fb, extend the interval in the a direction (reducing it) + { + a = a / factor; + fa = f(a); + } + else // fb is closer to 0 than fa, extend the interval in the b direction (augmenting it) + { + b = b * factor; + fb = f(b); + } + i = i+1; + + if(i % 10 == 0) // every 10 iterations, bump up the factor to speed up convergence + factor = factor * 5; + } + + internal::swap_ab_fafb_if_not_ascending(a, b, fa, fb); + return std::make_tuple(a, b, fa, fb, static_cast(fa*fb <= T(0))); + } + + /// Finds a bracket that encloses the solution and then calls TOMS748_solve1 to solve for the root of the function in the bracket. + /// Returns x, f(x), and a boolean indicating if the function converged or not (true if converged). + template + std::tuple bracket_and_solve(Func f, T x, T tol, T factor = T(2), unsigned int Nmax = 1000) + { + T a, b, fa, fb; + bool ok; + std::tie(a, b, fa, fb, ok) = findEnclosingBracket(f, x, factor, Nmax); + if(ok) + return TOMS748_solve1(f, a, b, fa, fb, tol, Nmax); + else + { + internal::swap_abs_ab_fafb_if_not_ascending(fa, fb, a, b); + return std::make_tuple(a, fa, false); + } + } }// namespace TOMS748 #endif diff --git a/cpp/test/1_test.cpp b/cpp/test/1_test.cpp index fa45c9b..f33fb1a 100644 --- a/cpp/test/1_test.cpp +++ b/cpp/test/1_test.cpp @@ -2,33 +2,249 @@ #include #include "../utils.hpp" +#include "../TOMS748.hpp" #include "utils_test.hpp" #define print(x) PRINT_VAR(x); #define printvec(x) PRINT_VEC(x); #define printstr(x) PRINT_STR(x); -TEST_CASE( "Test case 1", "[test1]" ) +template T f(T E) { return E - 0.5*sin(E) - 0.3; } +template T f2(T x) { return exp(x)-T(2); } +template T f3(T x) { return (x*x)-T(1); } + +using namespace TOMS748; +using namespace TOMS748::internal; + +TEST_CASE( "sureAbs", "[TOMS748]" ) +{ + REQUIRE(TOMS748::internal::sureAbs(-1) == 1); + REQUIRE(TOMS748::internal::sureAbs(-1.) == 1.); +} + +TEST_CASE( "checkTwoValuesIdentical", "[TOMS748]" ) { std::cout.precision(16); - SECTION( "Check almost equal" ) { - CHECK(check_almost_equal(1.00, 1.01, 0.1)); - CHECK(check_almost_equal(1.00, 3.01, 0.01)); - CHECK(check_almost_equal(1.00, 1.01, 0.001)); - REQUIRE(true); - } - - SECTION( "Check almost equal on vectors" ) { - unsigned int N = 5; - std::vector v1(N), v2(N); - for (size_t i = 0; i < N; i++) { - v1[i] = double(i); - v2[i] = v1[i] + 0.0001; - } - - CHECK(check_almost_equalV(v1, v2, 0.001)); - CHECK(check_almost_equalV(v1, v2, 0.0001)); - CHECK(check_almost_equalV(v1, v2, 0.00001)); + SECTION( "checkTwoValuesIdentical" ) { + REQUIRE(checkTwoValuesIdentical(1,2,3,4) == false); + REQUIRE(checkTwoValuesIdentical(1,2,3,1) == true); + REQUIRE(checkTwoValuesIdentical(1,1,3,4) == true); + REQUIRE(checkTwoValuesIdentical(1,2,2,4) == true); + REQUIRE(checkTwoValuesIdentical(1,2,3,3) == true); + REQUIRE(checkTwoValuesIdentical(1,2,1,4) == true); + REQUIRE(checkTwoValuesIdentical(1,2,3,2) == true); + } +} + +TEST_CASE( "swap_ab_fafb_if_not_ascending", "[TOMS748]" ) +{ + std::cout.precision(16); + + SECTION( "ascending (do nothing)" ) { + double a = 1., b = 2., + fa = f(a), fb = f(b); + + swap_ab_fafb_if_not_ascending(a,b,fa,fb); + + REQUIRE(a == 1.); + REQUIRE(b == 2.); + REQUIRE(fa == f(1.)); + REQUIRE(fb == f(2.)); + } + + SECTION( "descending (swap)" ) { + double a = 2., b = 1., + fa = f(a), fb = f(b); + + swap_ab_fafb_if_not_ascending(a,b,fa,fb); + + REQUIRE(a == 1.); + REQUIRE(b == 2.); + REQUIRE(fa == f(1.)); + REQUIRE(fb == f(2.)); + } +} + +TEST_CASE( "Bracket", "[TOMS748]" ) +{ + std::cout.precision(16); + + double a = 0, b = 1, c = 0.3, d = 1.2; + double fa = f(a), fb = f(b), fc = f(c), fd = f(d); + double tol = 1e-12; + bool ok; + + SECTION( "checkTwoValuesIdentical" ) { + REQUIRE(check_almost_equal(fbracket1(a, b, fa, fb), 0.5792645075960517, tol)); + REQUIRE(check_almost_equal(fbracket2(a, b, d, fa, fb, fd), 0.1619293662546865, tol)); + } + + SECTION( "bracket normal operation [c b]" ) { + std::tie(a, b, d, fa, fb, fd) = bracket(a, b, c, fa, fb, fc); + REQUIRE(check_almost_equal(a, 0.3, tol)); + REQUIRE(check_almost_equal(b, 1., tol)); + REQUIRE(check_almost_equal(d, 0., tol)); + REQUIRE(check_almost_equal(fa, f(0.3), tol)); + REQUIRE(check_almost_equal(fb, f(1.), tol)); + REQUIRE(check_almost_equal(fd, f(0.), tol)); + } + + SECTION( "bracket normal operation [a c]" ) { + std::tie(a, b, d, fa, fb, fd) = bracket(a, b, 0.7, fa, fb, f(0.7)); + REQUIRE(check_almost_equal(a, 0., tol)); + REQUIRE(check_almost_equal(b, 0.7, tol)); + REQUIRE(check_almost_equal(d, 1., tol)); + REQUIRE(check_almost_equal(fa, f(0.), tol)); + REQUIRE(check_almost_equal(fb, f(0.7), tol)); + REQUIRE(check_almost_equal(fd, f(1.), tol)); + } + + SECTION( "bracket and check normal operation [c b]" ) { + std::tie(a, b, d, fa, fb, fd, ok) = bracketAndCheckConvergence(a, b, c, fa, fb, fc, tol); + REQUIRE(check_almost_equal(a, 0.3, tol)); + REQUIRE(check_almost_equal(b, 1., tol)); + REQUIRE(check_almost_equal(d, 0., tol)); + REQUIRE(check_almost_equal(fa, f(0.3), tol)); + REQUIRE(check_almost_equal(fb, f(1.), tol)); + REQUIRE(check_almost_equal(fd, f(0.), tol)); + REQUIRE_FALSE(ok); + } + + SECTION( "bracket and check normal operation [a c]" ) { + std::tie(a, b, d, fa, fb, fd, ok) = bracketAndCheckConvergence(a, b, 0.7, fa, fb, f(0.7), tol); + REQUIRE(check_almost_equal(a, 0., tol)); + REQUIRE(check_almost_equal(b, 0.7, tol)); + REQUIRE(check_almost_equal(d, 1., tol)); + REQUIRE(check_almost_equal(fa, f(0.), tol)); + REQUIRE(check_almost_equal(fb, f(0.7), tol)); + REQUIRE(check_almost_equal(fd, f(1.), tol)); + REQUIRE_FALSE(ok); + } + + SECTION( "bracket and check normal operation c is root" ) { + std::tie(a, b, d, fa, fb, fd, ok) = bracketAndCheckConvergence(a, b, 0.569682256443945, fa, fb, f(0.569682256443945), tol); + REQUIRE(check_almost_equal(a, 0., tol)); + REQUIRE(check_almost_equal(b, 0.569682256443945, tol)); + REQUIRE(check_almost_equal(d, 1., tol)); + REQUIRE(check_almost_equal(fa, f(0.), tol)); + REQUIRE(check_almost_equal(fb, f(0.569682256443945), tol)); + REQUIRE(check_almost_equal(fd, f(1.), tol)); + REQUIRE(ok); + } +} + +TEST_CASE( "TOMS748_solve1", "[TOMS748]") +{ + double a = 0, b = 1, tol = 1e-12; + + SECTION( "TOMS748_solve1 good" ) { + auto [x, fx, ok] = TOMS748_solve1(f, a, b, tol); + CHECK(check_almost_equal(fx, 0., tol)); + CHECK(ok == true); + } + + SECTION( "TOMS748_solve1 wrong initial bracket" ) { + auto [x, fx, ok] = TOMS748_solve1(f, 1., 2., tol);// still converges ! + CHECK(check_almost_equal(fx, 0., tol)); + CHECK(ok == true); + } +} + +TEST_CASE( "findEnclosingBracket", "[TOMS748]") +{ + double tol = 1e-12, factor = 2; + + SECTION( "findEnclosingBracket start at 0" ) { + double x = 0; + auto [a, b, fa, fb, ok] = findEnclosingBracket(f2, x, factor); + CHECK(a == x); + CHECK(a == b); + CHECK_FALSE(fa*fb <= 0.); + CHECK_FALSE(ok); + } + + SECTION( "findEnclosingBracket sol below" ) { + double x = 2.; + auto [a, b, fa, fb, ok] = findEnclosingBracket(f2, x, factor); + CHECK(check_almost_equal(b, x, tol)); + CHECK(a < b); + CHECK(fa*fb <= 0.); + CHECK(ok); + } + + SECTION( "findEnclosingBracket sol above" ) { + double x = .1; + auto [a, b, fa, fb, ok] = findEnclosingBracket(f2, x, factor); + CHECK(check_almost_equal(a, x, tol)); + CHECK(a < b); + CHECK(fa*fb <= 0.); + CHECK(ok); + } + + SECTION( "findEnclosingBracket initial bracket good" ) { + double x = .5; + auto [a, b, fa, fb, ok] = findEnclosingBracket(f2, x, factor); + CHECK(check_almost_equal(a, x, tol)); + CHECK(a < b); + CHECK(fa*fb <= 0.); + CHECK(ok); + } + + SECTION( "findEnclosingBracket sol above far far away" ) { + double x = 1e-10; + auto [a, b, fa, fb, ok] = findEnclosingBracket(f2, x, factor); + CHECK(check_almost_equal(a, x, tol)); + CHECK(a < b); + CHECK(fa*fb <= 0.); + CHECK(ok); + } + + SECTION( "findEnclosingBracket sol on the other side of 0 (does not converge)" ) { + double x = -1.; + auto [a, b, fa, fb, ok] = findEnclosingBracket(f2, x, factor); + CHECK(a < b); + CHECK_FALSE(fa*fb <= 0.); + CHECK_FALSE(ok); + } + + SECTION( "findEnclosingBracket sol below (negative)" ) { + double x = -.1; + auto [a, b, fa, fb, ok] = findEnclosingBracket(f3, x, factor); + CHECK(check_almost_equal(b, x, tol)); + CHECK(a < b); + CHECK(fa*fb <= 0.); + CHECK(ok); + } + + SECTION( "findEnclosingBracket sol above (negative)" ) { + double x = -2.; + auto [a, b, fa, fb, ok] = findEnclosingBracket(f3, x, factor); + CHECK(check_almost_equal(a, x, tol)); + CHECK(a < b); + CHECK(fa*fb <= 0.); + CHECK(ok); + } +} + +TEST_CASE( "bracket_and_solve", "[TOMS748]" ) +{ + double tol = 1e-12; + SECTION( "initial guess close to solution" ) { + auto [x, fx, ok] = bracket_and_solve(f, 0.5, tol); + CHECK(check_almost_equal(fx, 0., tol)); + CHECK(ok == true); + } + + SECTION( "initial guess far from solution" ) { + auto [x, fx, ok] = bracket_and_solve(f, 100., tol); + CHECK(check_almost_equal(fx, 0., tol)); + CHECK(ok == true); + } + + SECTION( "initial guess on other side of 0 compared to solution" ) { + auto [x, fx, ok] = bracket_and_solve(f, -1., tol); + CHECK_FALSE(check_almost_equal(fx, 0., tol)); + CHECK_FALSE(ok); } } diff --git a/cpp/test/Makefile b/cpp/test/Makefile index 3a12f85..d1d90f8 100644 --- a/cpp/test/Makefile +++ b/cpp/test/Makefile @@ -3,7 +3,7 @@ C = gcc COMMON_FLAGS = -Wall -MMD -fprofile-arcs -ftest-coverage C_FLAGS = $(COMMON_FLAGS) CC = g++ -CC_FLAGS = $(COMMON_FLAGS) -std=c++17 -O0 +CC_FLAGS = $(COMMON_FLAGS) -std=c++17 -O0 -g LD_FLAGS = -lgcov INCLUDES =