diff --git a/src/main.rs b/src/main.rs index f10bd10..f8974c9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,6 +19,16 @@ fn dfct(x : f64) -> f64 { } } +fn ddfct(x : f64) -> f64 { + // exp(x) - 2*cos(x)/x**2 - (x**2 - 2)*sin(x)/x**3 + if x != 0.0 { + let tmp0 : f64 = f64::powi(x, 2); + return (f64::exp(x) - (tmp0 - 2.0)*f64::sin(x)/f64::powi(x, 3) - 2.0*f64::cos(x)/tmp0) as f64; + } else { + return 2.0/3.0; + } +} + fn main() { println!("Testing Rust numerical solvers."); let x0 : f64 = 1.0; @@ -28,11 +38,15 @@ fn main() { let x_mathematica = -3.26650043678562449167148755288; let x_newton = univariate_solvers::newton_solve(&(fct as fn(f64) -> f64), &(dfct as fn(f64) -> f64), x0, tol, max_iter); let x_newton_num: f64 = univariate_solvers::newton_solve_num(&(fct as fn(f64) -> f64), x0, tol, dx_num, max_iter); + let x_halley: f64 = univariate_solvers::halley_solve(&(fct as fn(f64) -> f64), &(dfct as fn(f64) -> f64), &(ddfct as fn(f64) -> f64), x0, tol, max_iter, false).unwrap(); let x_bisection : f64 = univariate_solvers::bisection_solve(&(fct as fn(f64) -> f64), -5.0, 1.0, tol).unwrap(); let x_secant : f64 = univariate_solvers::secant_solve(&(fct as fn(f64) -> f64), -1.0, 1.0, tol, max_iter); + let x_ridder : f64 = univariate_solvers::ridder_solve(&(fct as fn(f64) -> f64), -5.0, 1.0, tol, max_iter).unwrap(); println!("Mathematica : x = {}\tf(x) = {}", x_mathematica, fct(x_mathematica)); println!("Newton's method : x = {}\tf(x) = {}", x_newton, fct(x_newton)); println!("Newton's method (num) : x = {}\tf(x) = {}", x_newton_num, fct(x_newton_num)); + println!("Halley's method : x = {}\tf(x) = {}", x_halley, fct(x_halley)); println!("Bisection : x = {}\tf(x) = {}", x_bisection, fct(x_bisection)); println!("Secant method : x = {}\tf(x) = {}", x_secant, fct(x_secant)); + println!("Ridder's method : x = {}\tf(x) = {}", x_ridder, fct(x_ridder)); } diff --git a/src/univariate_solvers.rs b/src/univariate_solvers.rs index 91a97c1..69bc98d 100644 --- a/src/univariate_solvers.rs +++ b/src/univariate_solvers.rs @@ -43,6 +43,37 @@ where F : Fn(f64) -> f64 }, x0, tol, max_iter); } +/// Halley's method for solving a function f(x) = 0 +/// @param f function to solve +/// @param df derivative of function f +/// @param ddf second derivative of function f +/// @param x0 initial guess +/// @param tol tolerance +/// @param max_iter maximum number of iterations +/// @return solution +/// @note This method is more efficient than Newton's method, but requires the second derivative of f +pub fn halley_solve(f: F, df: F2, ddf: F3, x0: f64, tol: f64, max_iter: u32, verbose: bool) -> Result +where F : Fn(f64) -> f64, F2 : Fn(f64) -> f64, F3 : Fn(f64) -> f64 +{ + let mut x: f64 = x0; + let mut f_x: f64; + let mut df_x: f64; + let mut ddf_x: f64; + for _i in 0..max_iter { + f_x = f(x); + df_x = df(x); + ddf_x = ddf(x); + if verbose { + println!("x = {}, f(x) = {}, df(x) = {}, ddf(x) = {}", x, f_x, df_x, ddf_x); + } + if f64::abs(f_x) < tol { + return Ok(x); + } + x = x - 2.0*f_x*df_x / (2.0*df_x.powi(2) - f_x*ddf_x); + } + return Err("Halley method did not converge after reaching the maximum number of iterations allowed.") +} + // -------------------------------------------------------------------- // ------------------------ Bracketing methods ------------------------ // -------------------------------------------------------------------- @@ -107,4 +138,55 @@ where F : Fn(f64) -> f64 } } return c; +} + +/// @brief Ridder's method for solving a function f(x) = 0 +/// @param f function to solve +/// @param a left bracket +/// @param b right bracket +/// @param tol tolerance +/// @return solution +/// @note The interval [a, b] must bracket the root, meaning f(a) and f(b) must be of a different sign. +pub fn ridder_solve(f : F, mut a : f64, mut b : f64, tol : f64, max_iter : u32) -> Result +where F : Fn(f64) -> f64 +{ + let mut fa: f64 = f(a); + let mut fb: f64 = f(b); + let mut c: f64; + let mut fc: f64; + let mut s: f64; + let mut dx: f64; + let mut x: f64; + let mut fx: f64; + let mut x_old: f64 = (a + b)/2.0; + if fa == 0.0 { return Ok(a); } + if fb == 0.0 { return Ok(b); } + if fa*fb > 0.0 { + return Err("Root is not bracketed") + } + for i in 0..max_iter { + // Compute the improved root x from Ridder's formula + c = 0.5*(a + b); fc = f(c); + s = f64::sqrt(fc.powi(2) - fa*fb); + if s != 0.0 { + dx = (c - a)*fc/s; + } else { + dx = (c - a)*fc; + } + if (fa - fb) < 0.0 { dx = -dx; } + x = c + dx; fx = f(x); + // Test for convergence + if i > 0 { + if f64::abs(x - x_old) < tol*f64::max(f64::abs(x),1.0) { return Ok(x) } + } + x_old = x; + // Re-bracket the root as tightly as possible + if fc*fx > 0.0 { + if fa*fx < 0.0 { b = x; fb = fx; } + else { a = x; fa = fx; } + } else { + a = c; b = x; fa = fc; fb = fx; + } + } + return Err("Maximum number of iterations exceeded.") } \ No newline at end of file