方阵相乘
Time Limit:1000MS Memory Limit:30000KB
Description
实现两个n*n 方阵相乘的Strassen 算法,这里假设 n 为 2 的方幂。
Input
第一行为一个正整数N,表示有几组测试数据。
每组测试数据的第一行为一个正整数n(1<=n<=100),n为2的方幂,表示方阵n*n
接下去的n行表示第一个方阵,每行有n个整数,用空格分开。
再接下去的n行表示第二个方阵,每行有n个整数,用空格分开。
Output
对于每组测试出据,输出n行,每行有n个整数,用空格分开,不能有多余的空格。
Sample Input
1
2
1 2
3 4
5 6
7 8
Sample Output
19 22
43 50
朴素的矩阵乘法
1#include <iostream>
2#include <cstdio>
3
4using namespace std;
5
6const int L = 103;
7
8int a[ L ][ L ], b[ L ][ L ], c[ L ][ L ];
9
10int main() {
11 int td, n, i, j, k, tmp;
12 scanf( "%d", &td );
13 while ( td-- ) {
14 scanf( "%d", &n );
15 for ( i = 0; i < n; ++i )
16 for ( j = 0; j < n; ++j )
17 scanf( "%d", &a[ i ][ j ] );
18 for ( i = 0; i < n; ++i )
19 for ( j = 0; j < n; ++j )
20 scanf( "%d", &b[ i ][ j ] );
21 for ( i = 0; i < n; ++i )
22 for ( j = 0; j < n; ++j ) {
23 tmp = 0;
24 for ( k = 0; k < n; ++k )
25 tmp += a[ i ][ k ] * b[ k ][ j ];
26 c[ i ][ j ] = tmp;
27 }
28 for ( i = 0; i < n; ++i ) {
29 printf( "%d", c[ i ][ 0 ] );
30 for ( j = 1; j < n; ++j )
31 printf( " %d", c[ i ][ j ] );
32 printf( "\n" );
33 }
34 }
35 return 0;
36}
37
Strassen 算法
1#include <iostream>
2#include <cstdio>
3
4using namespace std;
5
6#define L 102
7#define LIM 400
8
9typedef int Mat[ L ][ L ];
10
11Mat buf[ LIM ];
12int top;
13
14void input( int a[][L], int n ) {
15 int i, j;
16 for ( i = 1; i <= n; ++i ) {
17 for ( j = 1; j <= n; ++j ) {
18 scanf( "%d", &a[ i ][ j ] );
19 }
20 }
21}
22
23void output( int c[][L], int n ) {
24 int i, j;
25 for ( i = 1; i <= n; ++i ) {
26 for ( j = 1; j < n; ++j ) {
27 printf( "%d ", c[ i ][ j ] );
28 }
29 printf( "%d\n", c[ i ][ j ] );
30 }
31}
32
33void get( int a[][L], int a11[][L], int a12[][L], int a21[][L], int a22[][L], int n ) {
34 int i, j;
35 for ( i = 1; i <= n; ++i ) {
36 for ( j = 1; j <= n; ++j ) {
37 a11[ i ][ j ] = a[ i ][ j ];
38 a12[ i ][ j ] = a[ i ][ j + n ];
39 a21[ i ][ j ] = a[ i + n ][ j ];
40 a22[ i ][ j ] = a[ i + n ][ j + n ];
41 }
42 }
43}
44
45void put( int a[][L], int a11[][L], int a12[][L], int a21[][L], int a22[][L], int n ) {
46 int i, j;
47 for ( i = 1; i <= n; ++i ) {
48 for ( j = 1; j <= n; ++j ) {
49 a[ i ][ j ] = a11[ i ][ j ];
50 a[ i ][ j + n ] = a12[ i ][ j ];
51 a[ i + n ][ j ] = a21[ i ][ j ];
52 a[ i + n ][ j + n ] = a22[ i ][ j ];
53 }
54 }
55}
56
57void add( int c[][L], int a[][L], int b[][L], int n ) {
58 int i, j;
59 for ( i = 1; i <= n; ++i ) {
60 for ( j = 1; j <= n; ++j ) {
61 c[ i ][ j ] = a[ i ][ j ] + b[ i ][ j ];
62 }
63 }
64}
65
66void sub( int c[][L], int a[][L], int b[][L], int n ) {
67 int i, j;
68 for ( i = 1; i <= n; ++i ) {
69 for ( j = 1; j <= n; ++j ) {
70 c[ i ][ j ] = a[ i ][ j ] - b[ i ][ j ];
71 }
72 }
73}
74
75void mul( int c[][L], int a[][L], int b[][L], int n ) {
76#define ADD(m) Mat &m = buf[ top++ ]
77#define ADDS(a) ADD(a##11); ADD(a##12); ADD(a##21); ADD(a##22)
78#define ENTER ADDS(a); ADDS(b); ADDS(c); ADD(d1); ADD(d2); ADD(d3); ADD(d4); ADD(d5); ADD(d6); ADD(d7); ADD(t1); ADD(t2)
79#define LEAVE top -= 21
80
81 ENTER;
82
83 if ( top >= LIM ) {
84 // for debug
85 fprintf( stderr, "buf overflow!!" );
86 LEAVE;
87 return;
88 }
89
90
91 if ( n < 1 ) {
92 LEAVE;
93 return;
94 }
95 if ( n == 1 ) {
96 c[ 1 ][ 1 ] = a[ 1 ][ 1 ] * b[ 1 ][ 1 ];
97 LEAVE;
98 return;
99 }
100 n >>= 1;
101 get( a, a11, a12, a21, a22, n );
102 get( b, b11, b12, b21, b22, n );
103
104 add( t1, a11, a22, n );
105 add( t2, b11, b22, n );
106 mul( d1, t1, t2, n );
107
108 add( t1, a21, a22, n );
109 mul( d2, t1, b11, n );
110
111 sub( t2, b12, b22, n );
112 mul( d3, a11, t2, n );
113
114 sub( t2, b21, b11, n );
115 mul( d4, a22, t2, n );
116
117 add( t1, a11, a12, n );
118 mul( d5, t1, b22, n );
119
120 sub( t1, a21, a11, n );
121 add( t2, b11, b12, n );
122 mul( d6, t1, t2, n );
123
124 sub( t1, a12, a22, n );
125 add( t2, b21, b22, n );
126 mul( d7, t1, t2, n );
127
128 add( t1, d1, d4, n );
129 sub( t2, d5, d7, n );
130 sub( c11, t1, t2, n );
131
132 add( c12, d3, d5, n );
133
134 add( c21, d2, d4, n );
135
136 add( t1, d1, d3, n );
137 sub( t2, d2, d6, n );
138 sub( c22, t1, t2, n );
139
140 put( c, c11, c12, c21, c22, n );
141
142 LEAVE;
143}
144
145int main() {
146 int td, n, a[ L ][ L ], b[ L ][ L ], c[ L ][ L ];
147 scanf( "%d", &td );
148 while ( td-- > 0 ) {
149 top = 0;
150 scanf( "%d", &n );
151 input( a, n );
152 input( b, n );
153 mul( c, a, b, n );
154 output( c, n );
155 }
156 return 0;
157}
158
我的实现有点丑。。。