最近项目里总是要对很庞大的公式求导,很烦人,手工求导容易出错。
当然MATLAB是个好选择,不过当它要钱的时候,您可能就不这么认为了。
于是,实现了一个可以
编译期求导(不用担心运行时负担)的小型库,还不完全,仅支持多项式,sin,cos,pow,exp,log等函数求导。
后期的表达式优化做的不是很好。
下面是一些测试代码,完整的源码在
http://www.boostpro.com/vault/index.php?action=downloadfile&filename=[math]AD.zip实现部分很复杂,请多多指教。
只有1个函数, d(...)
支持高阶,多元求导。
d(exp, var)(value1, value2, ...)
exp内可以有多个变量,var表示要对其求导的变量,value表示求导以后用于计算表达式的变量的值。
比如:
d(d(x*x*x, x),x)(3.0) 表示对x*x*x求二阶导数在x=3.0时候的值。
d(d(x*x*y, x), y)(3.0, 4.0) 表示d(x*x*y)/(dxdy)在x=3.0,y=4.0的值。
d(d(x*x*x, x) +d(y*x, y), y) (2.0) 则表示 (d(x*x*x)/dx + d(y*x)/dy)/dy == 0。
可以直接用cout把求导后的表达式输出,不用给变量给值。
cout<<d(x*x, x) // 结果是:2*x
这里没有用任何迭代,是直接对表达式求导的。返回值是求导后的表达式,本质是一个仿函数。可以用boost::function保存起来使用。
例如:
boost::function<double (double)> df = d(pow(x, const_<10>::type()), x); //df 参数为1个double,返回double
然后就可以在任何地方使用 df 了:
double res = df(3.0) // res == pow(3, 9)
1
#include "ad.h"
2
#include <iostream>
3
#include <iterator>
4![](http://www.cppblog.com/Images/OutliningIndicators/None.gif)
5
using namespace std;
6![](http://www.cppblog.com/Images/OutliningIndicators/None.gif)
7
int main()
8![](http://www.cppblog.com/Images/OutliningIndicators/ExpandedBlockStart.gif)
![](http://www.cppblog.com/Images/OutliningIndicators/ContractedBlock.gif)
{
9
variable<0>::type x;
10
variable<1>::type y;
11![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
12
double res[14];
13![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
14
res[0] = d(pow(x, const_<10>::type()), x)(2.0);
15![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
16
res[1] = d(x * x * x, x)(2.0);
17
res[2] = d(x + x + x, x)(2.0);
18
res[3] = d(x - x - x, x)(2.0);
19
res[4] = d(x / x, x)(2.0);
20![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
21
res[5] = d(pow(x, var(3.0)), x)(2.0);
22
res[6] = d(pow(var(3.0), x), x)(2.0);
23
res[7] = d(pow(x, x), x)(2.0);
24![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
25
res[8] = d(log(x), x)(2.0);
26
res[9] = d(exp(x), x)(2.0);
27![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
28
res[10] = d(sin(x), x)(2.0);
29
res[11] = d(cos(x), x)(2.0);
30![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
31
res[12] = d(d(sin(x) * cos(y), x), y)(2.0, 3.0);
32![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
33
res[13] = (d(log(x) + x, x) * x)(2.0);
34![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
35
copy(res, res + 14, ostream_iterator<double>(cout, "\n"));
36![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
37
cout<<d(pow(x, const_<10>::type()), x)<<endl;
38![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
39
cout<<d(x * x * x, x)<<endl;
40
cout<<d(x + x + x, x)<<endl;
41
cout<<d(x - x - x, x)<<endl;
42
cout<<d(x / x / x, x)<<endl;
43![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
44
cout<<d(pow(x, var(3.0)), x)<<endl;
45
cout<<d(pow(var(3.0), x), x)<<endl;
46
cout<<d(pow(x, x), x)<<endl;
47![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
48
cout<<d(log(x), x)<<endl;
49
cout<<d(exp(x), x)<<endl;
50![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
51
cout<<d(sin(x), x)<<endl;
52
cout<<d(cos(x), x)<<endl;
53![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
54
cout<<d(d(sin(x) * cos(y), x), y)<<endl;
55![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
56
cout<<(d(log(x) + x, x) * x)<<endl;
57![](http://www.cppblog.com/Images/OutliningIndicators/InBlock.gif)
58
return 0;
59
}
60![](http://www.cppblog.com/Images/OutliningIndicators/None.gif)
输出结果如下:
1
512
2
12
3
3
4
-1
5
0
6
12
7
9.88751
8
6.77259
9
0.5
10
7.38906
11
-0.416147
12
0.909297
13
-0.0587266
14
3
15
pow(x,9)
16
(((x+x)*x)+(x*x))
17
3
18
-1
19
(-1/(x*x))
20
(pow(x,3)*(3*(1/x)))
21
(pow(3,x)*log(3))
22
(pow(x,x)*(log(x)+1))
23
(1/x)
24
exp(x)
25
cos(x)
26
sin(x)
27
(cos(x)*sin(y))
28
(((1/x)+1)*x)
29![](http://www.cppblog.com/Images/OutliningIndicators/None.gif)
posted on 2009-05-01 23:50
尹东斐 阅读(2574)
评论(6) 编辑 收藏 引用