Static derivative computation test added (boost::proto).

This commit is contained in:
Jérôme 2019-03-27 21:27:39 +01:00
parent 5c57aa0afd
commit 53a90f6346
4 changed files with 187 additions and 3 deletions

View file

@ -29,14 +29,14 @@ T ddf(const T & x)
int main()
{
double xdbl = 1.5;
{
cout << "Analytical\n";
cout << "f(x) = " << f(xdbl) << endl;
cout << "df(x)/dt = " << df(xdbl) << endl;
cout << "d²f(x)/dt = " << ddf(xdbl) << endl;
}
// 1st derivative forward
{
using Fd = Dual<double>;
@ -47,7 +47,7 @@ int main()
cout << "f(x) = " << y.a << endl;
cout << "df(x)/dt = " << y.d(0) << endl;
}
// 2nd derivative forward
/*
{

View file

@ -0,0 +1,33 @@
# Declaration of variables
C = clang
C_FLAGS = -Wall
CC = clang++
CC_FLAGS = -Wall -std=c++17 -O0
LD_FLAGS =
INCLUDES =
# File names
EXEC = run
CSOURCES = $(wildcard *.c)
COBJECTS = $(CSOURCES:.c=.o)
SOURCES = $(wildcard *.cpp)
OBJECTS = $(SOURCES:.cpp=.o)
# Main target
$(EXEC): $(COBJECTS) $(OBJECTS)
$(CC) $(LD_FLAGS) $(COBJECTS) $(OBJECTS) -o $(EXEC)
# To obtain object files
%.o: %.cpp
$(CC) $(INCLUDES) $(CC_FLAGS) -o $@ -c $<
# To obtain object files
%.o: %.c
$(C) $(INCLUDES) $(C_FLAGS) -o $@ -c $<
# To remove generated files
clean:
rm -f $(COBJECTS) $(OBJECTS)
cleaner:
rm -f $(EXEC) $(COBJECTS) $(OBJECTS)

View file

@ -0,0 +1,112 @@
#ifndef DEF_compile_time_derivative
#define DEF_compile_time_derivative
#include <boost/proto/proto.hpp>
using namespace boost::proto;
// Assuming derivative of one variable, the 'unknown'
struct unknown {};
// Boost.Proto calls this the expression wrapper
// elements of the EDSL will have this type
template<typename Expr>
struct expression;
// Boost.Proto calls this the domain
struct derived_domain
: domain<generator<expression>> {};
// We will use a context to evaluate expression templates
struct evaluation_context: callable_context<evaluation_context const> {
double value;
explicit evaluation_context(double value)
: value(value)
{}
typedef double result_type;
double operator()(tag::terminal, unknown) const
{ return value; }
};
// And now we can do:
// evaluation_context context(42);
// eval(expr, context);
// to evaluate an expression as though the unknown had value 42
template<typename Expr>
struct expression: extends<Expr, expression<Expr>, derived_domain> {
typedef extends<Expr, expression<Expr>, derived_domain> base_type;
expression(Expr const& expr = Expr())
: base_type(expr)
{}
typedef double result_type;
// We spare ourselves the need to write eval(expr, context)
// Instead, expr(42) is available
double operator()(double d) const
{
evaluation_context context(d);
return eval(*this, context);
}
};
// Boost.Proto calls this a transform -- we use this to operate
// on the expression templates
struct Derivative
: or_<
when<
terminal<unknown>
, boost::mpl::int_<1>()
>
, when<
terminal<_>
, boost::mpl::int_<0>()
>
, when<
plus<Derivative, Derivative>
, _make_plus(Derivative(_left), Derivative(_right))
>
, when<
minus<Derivative, Derivative>
, _make_minus(Derivative(_left), Derivative(_right))
>
, when<
multiplies<Derivative, Derivative>
, _make_plus(
_make_multiplies(Derivative(_left), _right)
, _make_multiplies(_left, Derivative(_right))
)
>
//*
, when<
divides<Derivative, Derivative>
, _make_divides
(
_make_minus
(
_make_multiplies
(
_right,
Derivative(_left)
),
_make_multiplies
(
_left,
Derivative(_right)
)
),
_make_multiplies
(
_right,
_right
)
)
>
, otherwise<_>
> {};
#endif

View file

@ -0,0 +1,39 @@
#include <iostream>
#include "utils.hpp"
#include "compile_time_derivative.hpp"
using std::cout;
using std::endl;
// x is the unknown
expression<terminal<unknown>::type> const x;
// A transform works as a functor
Derivative const derivative;
template<typename T>
T fct(const T & x) {
return (3)*x*x + (2)*x - (3) + (1)/x;
}
template<typename T>
T dfct(const T & x) {
return (6)*x + (2) - (1)/(x*x);
}
// The file must be compiled with -O0 because with -O2 and up
int main()
{
auto func = (3*x*x + 2*x - 3 + 1./x);
// auto func = fct(x);
auto dfunc = derivative(func);
cout.precision(16);
for(double i = 0.1 ; i < 2. ; i+=.1)
std::cout << func(i) << "\t" << fct(i) << '\n';
for(double i = 0.1 ; i < 2. ; i+=.1)
std::cout << func(i) << "\t" << fct(i) << "\t" << dfunc(i) << "\t" << dfct(i) << '\n';
return 0;
}