我的 FFT（NTT）大数乘法模板。这是第一次实现，代码还不够精简，有待优化。
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>

using namespace std;

/*
 * Number-theoretic transform: an FFT over the integers modulo a prime p.
 *
 *   L  - capacity of the internal work buffer; the transform length must
 *        not exceed L.
 *   T  - element type; also holds the modulus p.
 *   LT - wider type for intermediate products; (p - 1) * (p - 1) must fit.
 *
 * p is the smallest prime of the form k * 2^21 + 1 (k >= 3) that is at least
 * 10^7, so transform lengths up to 2^21 are supported.
 */
template< int L, class T = int, class LT = long long >
class FFT
{
public :
    FFT() : p( 0 ), n( -1 ), bit( 0 ) {
    }

    /*
     * Forward transform of e[0..m).
     * The data are zero-padded to n, the smallest power of two (at least 2)
     * that is >= max(m, minL). On return m == n and e[0..n) holds the
     * transform, so e needs room for n values.
     * Prints an error and aborts if n would exceed L or 2^21.
     */
    void fft( T e[], int &m, int minL ) {
        in( e, m, minL );
        m = n;
        transform( false );
        out( e );
    }

    /* Inverse transform (including the 1/n scaling); same conventions as fft(). */
    void ifft( T e[], int &m, int minL ) {
        in( e, m, minL );
        m = n;
        transform( true );
        out( e );
    }

    /* Modulus used by the most recent transform; 0 before the first one. */
    T getP() {
        return p;
    }

private :
    /* MAX_LOG: exponent of 2 in p - 1, i.e. log2 of the longest transform.
       MIN_P:   lower bound for the modulus p. */
    enum { MAX_LOG = 21, MIN_P = 10000000 };

    /* Trial division; i * i is computed in LT so it cannot overflow T. */
    int isPrime( T x ) {
        if ( x < 2 ) {
            return 0;
        }
        for ( T i = 2; (LT)i * i <= x; ++i ) {
            if ( x % i == 0 ) {
                return 0;
            }
        }
        return 1;
    }

    /* a^b mod c by binary exponentiation; products are formed in LT. */
    T powMod( T a, T b, T c ) {
        T ans = 1;
        a %= c;
        while ( b > 0 ) {
            if ( b & 1 ) {
                ans = (T)( (LT)ans * a % c );
            }
            a = (T)( (LT)a * a % c );
            b >>= 1;
        }
        return ans;
    }

    /* Returns 1 iff g is a primitive root modulo the prime p, i.e. no proper
       divisor d of p - 1 satisfies g^d == 1 (mod p). */
    int isG( T g, T p ) {
        T p0 = p - 1;
        for ( T i = 1; (LT)i * i <= p0; ++i ) {
            if ( p0 % i != 0 ) {
                continue;
            }
            if ( i < p0 && powMod( g, i, p ) == 1 ) {
                return 0;
            }
            if ( p0 / i < p0 && powMod( g, p0 / i, p ) == 1 ) {
                return 0;
            }
        }
        return 1;
    }

    /* Reverses the low `bit` bits of i. */
    int rev_bit( int i ) {
        int j = 0;
        for ( int k = 0; k < bit; ++k ) {
            j = ( j << 1 ) | ( i & 1 );
            i >>= 1;
        }
        return j;
    }

    /* Bit-reversal permutation of a[0..n), needed before the iterative butterflies. */
    void reverse() {
        for ( int i = 0; i < n; ++i ) {
            int j = rev_bit( i );
            if ( i < j ) {
                T t = a[ i ];
                a[ i ] = a[ j ];
                a[ j ] = t;
            }
        }
    }

    /*
     * Chooses the transform length n = 2^bit, copies e[0..m) into a[] and
     * zero-pads it to n. The modulus and root table are rebuilt only when n
     * changes.
     */
    void in( T e[], int m, int minL ) {
        /* never pick a length shorter than the input, or elements would be dropped */
        int need = ( m > minL ) ? m : minL;
        bit = 1;
        while ( bit <= MAX_LOG && ( 1 << bit ) < need ) {
            ++bit;
        }
        /* a[] holds only L values, and p - 1 has only MAX_LOG factors of two */
        if ( bit > MAX_LOG || ( 1 << bit ) > L ) {
            std::fprintf( stderr, "FFT: transform of length %d exceeds capacity\n", need );
            std::abort();
        }
        int changed = ( n != ( 1 << bit ) );
        n = 1 << bit;
        for ( int i = 0; i < m; ++i ) {
            a[ i ] = e[ i ];
        }
        for ( int i = m; i < n; ++i ) {
            a[ i ] = 0;
        }
        if ( changed ) {
            init( MAX_LOG, MIN_P );
        }
    }

    /*
     * Finds p = k * 2^lim2 + 1 (smallest k >= 3 giving a prime >= minP) and
     * a primitive root ig of p. It then stores g[i] = ig^((p - 1) / 2^(bit - i)),
     * a root of unity of order 2^(bit - i). Requires lim2 >= bit.
     */
    void init( int lim2, T minP ) {
        T k = 2, ig = 2;
        do {
            ++k;
            p = ( k << lim2 ) | 1;
        } while ( p < minP || !isPrime( p ) );
        while ( !isG( ig, p ) ) {
            ++ig;
        }
        for ( int i = 0; i < bit; ++i ) {
            g[ i ] = powMod( ig, k << ( lim2 - bit + i ), p );
        }
    }

    /*
     * In-place iterative radix-2 transform of a[0..n).
     * When invert is true, inverse roots are used and the result is scaled
     * by 1/n. Sums are formed in LT so u + t cannot overflow T.
     */
    void transform( bool invert ) {
        reverse();
        for ( int s = bit - 1; s >= 0; --s ) {
            int m2 = 1 << ( bit - s );
            int m = m2 >> 1;
            T wm = invert ? powMod( g[ s ], p - 2, p ) : g[ s ];
            for ( int k = 0; k < n; k += m2 ) {
                T w = 1;
                for ( int j = 0; j < m; ++j ) {
                    T t = (T)( (LT)w * a[ k + j + m ] % p );
                    T u = a[ k + j ];
                    a[ k + j ] = (T)( ( (LT)u + t ) % p );
                    a[ k + j + m ] = (T)( ( (LT)u + p - t ) % p );
                    w = (T)( (LT)w * wm % p );
                }
            }
        }
        if ( invert ) {
            T inv = powMod( n, p - 2, p );  // 1/n by Fermat's little theorem
            for ( int k = 0; k < n; ++k ) {
                a[ k ] = (T)( (LT)inv * a[ k ] % p );
            }
        }
    }

    /* Copies the transformed a[0..n) back into e. */
    void out( T e[] ) {
        for ( int i = 0; i < n; ++i ) {
            e[ i ] = a[ i ];
        }
    }

    T a[ L ], g[ 100 ], p;   // work buffer, roots g[s] of order 2^(bit - s), modulus
    int n, bit;              // transform length n = 2^bit (n == -1 before first use)
};




183 #define L 140000
184 typedef long long Lint;
185
186 FFT< L, int, Lint > fft;
187 char s[ L ];
188
/*
 * Prints the big integer a (a[0] = digit count, a[1..a[0]] = digits, least
 * significant first) to stdout, most significant digit first, followed by a
 * newline. Digits are written directly, so the printable length is not
 * limited by the size of the scratch buffer s.
 */
void bi_out( int a[] ) {
    for ( int i = a[ 0 ]; i >= 1; --i ) {
        putchar( '0' + a[ i ] );
    }
    putchar( '\n' );
}

199 int bi_in( int a[] ) {
200 int i, n;
201 if ( scanf( "%s", s ) != 1 ) {
202 return 0;
203 }
204 a[ 0 ] = n = strlen( s );
205 for ( i = 0; i < n; ++i ) {
206 a[ n - i ] = s[ i ] - '0';
207 }
208 return 1;
209 }
210
211 void bi_mul( int c[], int a[], int b[] ) {
212 int m, n, p, g, i;
213
214 n = ( (a[0]>b[0]) ? a[0] : b[0] );
215 n <<= 1;
216
217 m = a[ 0 ];
218 fft.fft( a+1, m, n );
219
220 m = b[ 0 ];
221 fft.fft( b+1, m, n );
222
223 p = fft.getP();
224
225 for ( i = 1; i <= m; ++i ) {
226 c[ i ] = (Lint)a[ i ] * b[ i ] % p;
227 }
228 fft.ifft( c+1, m, m );
229 g = 0;
230 for ( i = 1; i <= m; ++i ) {
231 g += c[ i ];
232 c[ i ] = g % 10;
233 g /= 10;
234 }
235 for ( i = a[0]+b[0]; (i>1)&&(c[i]==0); --i )
236 ;
237 c[ 0 ] = i;
238 }
239
240 int a[ L ], b[ L ], c[ L ];
241
242 int main() {
243 while ( bi_in( a ) && bi_in( b ) ) {
244 bi_mul( c, a, b );
245 bi_out( c );
246 }
247 return 0;
248 }
249