Added Halley's method.

This commit is contained in:
Jérôme 2023-03-26 00:08:03 +01:00
parent bc11ee8616
commit 176168924a
2 changed files with 96 additions and 0 deletions

View file

@ -19,6 +19,16 @@ fn dfct(x : f64) -> f64 {
} }
} }
fn ddfct(x : f64) -> f64 {
// exp(x) - 2*cos(x)/x**2 - (x**2 - 2)*sin(x)/x**3
if x != 0.0 {
let tmp0 : f64 = f64::powi(x, 2);
return (f64::exp(x) - (tmp0 - 2.0)*f64::sin(x)/f64::powi(x, 3) - 2.0*f64::cos(x)/tmp0) as f64;
} else {
return 2.0/3.0;
}
}
fn main() { fn main() {
println!("Testing Rust numerical solvers."); println!("Testing Rust numerical solvers.");
let x0 : f64 = 1.0; let x0 : f64 = 1.0;
@ -28,11 +38,15 @@ fn main() {
let x_mathematica = -3.26650043678562449167148755288; let x_mathematica = -3.26650043678562449167148755288;
let x_newton = univariate_solvers::newton_solve(&(fct as fn(f64) -> f64), &(dfct as fn(f64) -> f64), x0, tol, max_iter); let x_newton = univariate_solvers::newton_solve(&(fct as fn(f64) -> f64), &(dfct as fn(f64) -> f64), x0, tol, max_iter);
let x_newton_num: f64 = univariate_solvers::newton_solve_num(&(fct as fn(f64) -> f64), x0, tol, dx_num, max_iter); let x_newton_num: f64 = univariate_solvers::newton_solve_num(&(fct as fn(f64) -> f64), x0, tol, dx_num, max_iter);
let x_halley: f64 = univariate_solvers::halley_solve(&(fct as fn(f64) -> f64), &(dfct as fn(f64) -> f64), &(ddfct as fn(f64) -> f64), x0, tol, max_iter, false).unwrap();
let x_bisection : f64 = univariate_solvers::bisection_solve(&(fct as fn(f64) -> f64), -5.0, 1.0, tol).unwrap(); let x_bisection : f64 = univariate_solvers::bisection_solve(&(fct as fn(f64) -> f64), -5.0, 1.0, tol).unwrap();
let x_secant : f64 = univariate_solvers::secant_solve(&(fct as fn(f64) -> f64), -1.0, 1.0, tol, max_iter); let x_secant : f64 = univariate_solvers::secant_solve(&(fct as fn(f64) -> f64), -1.0, 1.0, tol, max_iter);
let x_ridder : f64 = univariate_solvers::ridder_solve(&(fct as fn(f64) -> f64), -5.0, 1.0, tol, max_iter).unwrap();
println!("Mathematica : x = {}\tf(x) = {}", x_mathematica, fct(x_mathematica)); println!("Mathematica : x = {}\tf(x) = {}", x_mathematica, fct(x_mathematica));
println!("Newton's method : x = {}\tf(x) = {}", x_newton, fct(x_newton)); println!("Newton's method : x = {}\tf(x) = {}", x_newton, fct(x_newton));
println!("Newton's method (num) : x = {}\tf(x) = {}", x_newton_num, fct(x_newton_num)); println!("Newton's method (num) : x = {}\tf(x) = {}", x_newton_num, fct(x_newton_num));
println!("Halley's method : x = {}\tf(x) = {}", x_halley, fct(x_halley));
println!("Bisection : x = {}\tf(x) = {}", x_bisection, fct(x_bisection)); println!("Bisection : x = {}\tf(x) = {}", x_bisection, fct(x_bisection));
println!("Secant method : x = {}\tf(x) = {}", x_secant, fct(x_secant)); println!("Secant method : x = {}\tf(x) = {}", x_secant, fct(x_secant));
println!("Ridder's method : x = {}\tf(x) = {}", x_ridder, fct(x_ridder));
} }

View file

@ -43,6 +43,37 @@ where F : Fn(f64) -> f64
}, x0, tol, max_iter); }, x0, tol, max_iter);
} }
/// Halley's method for solving a function f(x) = 0
/// @param f function to solve
/// @param df derivative of function f
/// @param ddf second derivative of function f
/// @param x0 initial guess
/// @param tol tolerance
/// @param max_iter maximum number of iterations
/// @return solution
/// @note This method is more efficient than Newton's method, but requires the second derivative of f
pub fn halley_solve<F, F2, F3>(f: F, df: F2, ddf: F3, x0: f64, tol: f64, max_iter: u32, verbose: bool) -> Result<f64, &'static str>
where F : Fn(f64) -> f64, F2 : Fn(f64) -> f64, F3 : Fn(f64) -> f64
{
let mut x: f64 = x0;
let mut f_x: f64;
let mut df_x: f64;
let mut ddf_x: f64;
for _i in 0..max_iter {
f_x = f(x);
df_x = df(x);
ddf_x = ddf(x);
if verbose {
println!("x = {}, f(x) = {}, df(x) = {}, ddf(x) = {}", x, f_x, df_x, ddf_x);
}
if f64::abs(f_x) < tol {
return Ok(x);
}
x = x - 2.0*f_x*df_x / (2.0*df_x.powi(2) - f_x*ddf_x);
}
return Err("Halley method did not converge after reaching the maximum number of iterations allowed.")
}
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// ------------------------ Bracketing methods ------------------------ // ------------------------ Bracketing methods ------------------------
// -------------------------------------------------------------------- // --------------------------------------------------------------------
@ -108,3 +139,54 @@ where F : Fn(f64) -> f64
} }
return c; return c;
} }
/// @brief Ridder's method for solving a function f(x) = 0
/// @param f function to solve
/// @param a left bracket
/// @param b right bracket
/// @param tol tolerance
/// @return solution
/// @note The interval [a, b] must bracket the root, meaning f(a) and f(b) must be of a different sign.
pub fn ridder_solve<F>(f : F, mut a : f64, mut b : f64, tol : f64, max_iter : u32) -> Result<f64, &'static str>
where F : Fn(f64) -> f64
{
let mut fa: f64 = f(a);
let mut fb: f64 = f(b);
let mut c: f64;
let mut fc: f64;
let mut s: f64;
let mut dx: f64;
let mut x: f64;
let mut fx: f64;
let mut x_old: f64 = (a + b)/2.0;
if fa == 0.0 { return Ok(a); }
if fb == 0.0 { return Ok(b); }
if fa*fb > 0.0 {
return Err("Root is not bracketed")
}
for i in 0..max_iter {
// Compute the improved root x from Ridder's formula
c = 0.5*(a + b); fc = f(c);
s = f64::sqrt(fc.powi(2) - fa*fb);
if s != 0.0 {
dx = (c - a)*fc/s;
} else {
dx = (c - a)*fc;
}
if (fa - fb) < 0.0 { dx = -dx; }
x = c + dx; fx = f(x);
// Test for convergence
if i > 0 {
if f64::abs(x - x_old) < tol*f64::max(f64::abs(x),1.0) { return Ok(x) }
}
x_old = x;
// Re-bracket the root as tightly as possible
if fc*fx > 0.0 {
if fa*fx < 0.0 { b = x; fb = fx; }
else { a = x; fa = fx; }
} else {
a = c; b = x; fa = fc; fb = fx;
}
}
return Err("Maximum number of iterations exceeded.")
}