方阵相乘
Time Limit:1000MS Memory Limit:30000KB
Description
实现两个n*n 方阵相乘的Strassen 算法,这里假设 n 为 2 的方幂。
Input
第一行为一个正整数N,表示有几组测试数据。
每组测试数据的第一行为一个正整数n(1<=n<=100),n为2的方幂,表示方阵n*n
接下去的n行表示第一个方阵,每行有n个整数,用空格分开。
再接下去的n行表示第二个方阵,每行有n个整数,用空格分开。
Output
对于每组测试出据,输出n行,每行有n个整数,用空格分开,不能有多余的空格。
Sample Input
1
2
1 2
3 4
5 6
7 8
Sample Output
19 22
43 50
朴素的矩阵乘法
1
#include <iostream>
2
#include <cstdio>
3
4
using namespace std;
5
6
const int L = 103;
7
8
int a[ L ][ L ], b[ L ][ L ], c[ L ][ L ];
9
10
int main()
{
11
int td, n, i, j, k, tmp;
12
scanf( "%d", &td );
13
while ( td-- )
{
14
scanf( "%d", &n );
15
for ( i = 0; i < n; ++i )
16
for ( j = 0; j < n; ++j )
17
scanf( "%d", &a[ i ][ j ] );
18
for ( i = 0; i < n; ++i )
19
for ( j = 0; j < n; ++j )
20
scanf( "%d", &b[ i ][ j ] );
21
for ( i = 0; i < n; ++i )
22
for ( j = 0; j < n; ++j )
{
23
tmp = 0;
24
for ( k = 0; k < n; ++k )
25
tmp += a[ i ][ k ] * b[ k ][ j ];
26
c[ i ][ j ] = tmp;
27
}
28
for ( i = 0; i < n; ++i )
{
29
printf( "%d", c[ i ][ 0 ] );
30
for ( j = 1; j < n; ++j )
31
printf( " %d", c[ i ][ j ] );
32
printf( "\n" );
33
}
34
}
35
return 0;
36
}
37
Strassen 算法
1
#include <iostream>
2
#include <cstdio>
3
4
using namespace std;
5
6
#define L 102
7
#define LIM 400
8
9
typedef int Mat[ L ][ L ];
10
11
Mat buf[ LIM ];
12
int top;
13
14
void input( int a[][L], int n )
{
15
int i, j;
16
for ( i = 1; i <= n; ++i )
{
17
for ( j = 1; j <= n; ++j )
{
18
scanf( "%d", &a[ i ][ j ] );
19
}
20
}
21
}
22
23
void output( int c[][L], int n )
{
24
int i, j;
25
for ( i = 1; i <= n; ++i )
{
26
for ( j = 1; j < n; ++j )
{
27
printf( "%d ", c[ i ][ j ] );
28
}
29
printf( "%d\n", c[ i ][ j ] );
30
}
31
}
32
33
void get( int a[][L], int a11[][L], int a12[][L], int a21[][L], int a22[][L], int n )
{
34
int i, j;
35
for ( i = 1; i <= n; ++i )
{
36
for ( j = 1; j <= n; ++j )
{
37
a11[ i ][ j ] = a[ i ][ j ];
38
a12[ i ][ j ] = a[ i ][ j + n ];
39
a21[ i ][ j ] = a[ i + n ][ j ];
40
a22[ i ][ j ] = a[ i + n ][ j + n ];
41
}
42
}
43
}
44
45
void put( int a[][L], int a11[][L], int a12[][L], int a21[][L], int a22[][L], int n )
{
46
int i, j;
47
for ( i = 1; i <= n; ++i )
{
48
for ( j = 1; j <= n; ++j )
{
49
a[ i ][ j ] = a11[ i ][ j ];
50
a[ i ][ j + n ] = a12[ i ][ j ];
51
a[ i + n ][ j ] = a21[ i ][ j ];
52
a[ i + n ][ j + n ] = a22[ i ][ j ];
53
}
54
}
55
}
56
57
void add( int c[][L], int a[][L], int b[][L], int n )
{
58
int i, j;
59
for ( i = 1; i <= n; ++i )
{
60
for ( j = 1; j <= n; ++j )
{
61
c[ i ][ j ] = a[ i ][ j ] + b[ i ][ j ];
62
}
63
}
64
}
65
66
void sub( int c[][L], int a[][L], int b[][L], int n )
{
67
int i, j;
68
for ( i = 1; i <= n; ++i )
{
69
for ( j = 1; j <= n; ++j )
{
70
c[ i ][ j ] = a[ i ][ j ] - b[ i ][ j ];
71
}
72
}
73
}
74
75
void mul( int c[][L], int a[][L], int b[][L], int n )
{
76
#define ADD(m) Mat &m = buf[ top++ ]
77
#define ADDS(a) ADD(a##11); ADD(a##12); ADD(a##21); ADD(a##22)
78
#define ENTER ADDS(a); ADDS(b); ADDS(c); ADD(d1); ADD(d2); ADD(d3); ADD(d4); ADD(d5); ADD(d6); ADD(d7); ADD(t1); ADD(t2)
79
#define LEAVE top -= 21
80
81
ENTER;
82
83
if ( top >= LIM )
{
84
// for debug
85
fprintf( stderr, "buf overflow!!" );
86
LEAVE;
87
return;
88
}
89
90
91
if ( n < 1 )
{
92
LEAVE;
93
return;
94
}
95
if ( n == 1 )
{
96
c[ 1 ][ 1 ] = a[ 1 ][ 1 ] * b[ 1 ][ 1 ];
97
LEAVE;
98
return;
99
}
100
n >>= 1;
101
get( a, a11, a12, a21, a22, n );
102
get( b, b11, b12, b21, b22, n );
103
104
add( t1, a11, a22, n );
105
add( t2, b11, b22, n );
106
mul( d1, t1, t2, n );
107
108
add( t1, a21, a22, n );
109
mul( d2, t1, b11, n );
110
111
sub( t2, b12, b22, n );
112
mul( d3, a11, t2, n );
113
114
sub( t2, b21, b11, n );
115
mul( d4, a22, t2, n );
116
117
add( t1, a11, a12, n );
118
mul( d5, t1, b22, n );
119
120
sub( t1, a21, a11, n );
121
add( t2, b11, b12, n );
122
mul( d6, t1, t2, n );
123
124
sub( t1, a12, a22, n );
125
add( t2, b21, b22, n );
126
mul( d7, t1, t2, n );
127
128
add( t1, d1, d4, n );
129
sub( t2, d5, d7, n );
130
sub( c11, t1, t2, n );
131
132
add( c12, d3, d5, n );
133
134
add( c21, d2, d4, n );
135
136
add( t1, d1, d3, n );
137
sub( t2, d2, d6, n );
138
sub( c22, t1, t2, n );
139
140
put( c, c11, c12, c21, c22, n );
141
142
LEAVE;
143
}
144
145
int main()
{
146
int td, n, a[ L ][ L ], b[ L ][ L ], c[ L ][ L ];
147
scanf( "%d", &td );
148
while ( td-- > 0 )
{
149
top = 0;
150
scanf( "%d", &n );
151
input( a, n );
152
input( b, n );
153
mul( c, a, b, n );
154
output( c, n );
155
}
156
return 0;
157
}
158
我的实现有点丑。。。