在学习DirectX的过程中,没想到矩阵的作用这么大。以前学的数学基本上忘光了,看到矩阵竟然有点云里雾里!于是沉下心来,好好复习了一下矩阵,并用C++模板实现了一个矩阵类。这里只考虑易用性,不考虑性能问题。代码如下。
1 #ifndef _X_MATRIX_H_
2 #define _X_MATRIX_H_
3 #include <iostream>
4 #include <cassert>
5
6 using namespace std;
7
/// XMatrix<T>: a dense Rows x Cols matrix that owns its elements on the heap,
/// stored in row-major order. Indices are 1-based: rows 1..GetRow(),
/// columns 1..GetCol(). Written for readability, not speed; intended for
/// arithmetic element types (the D3DX-style helpers use XMatrix<float>).
template<class T>
class XMatrix
{
public:
    /// Creates an empty 0 x 0 matrix that owns no storage.
    XMatrix()
        : m_Data(NULL), m_Col(0), m_Row(0)
    {}

    /// Creates a parmaRow x paramCol matrix whose elements are
    /// value-initialized (0 for arithmetic T). Both sizes must be positive.
    XMatrix(int parmaRow, int paramCol)
        : m_Data(NULL), m_Col(0), m_Row(0)
    {
        assert(paramCol > 0 && parmaRow > 0);
        InitMatrix(parmaRow, paramCol);
    }

    /// Deep copy. Copying an empty (default-constructed) matrix is allowed
    /// and yields another empty matrix.
    XMatrix(const XMatrix<T> & paramMatrix)
        : m_Data(NULL), m_Col(0), m_Row(0)
    {
        InitMatrix(paramMatrix.m_Row, paramMatrix.m_Col);
        // An empty source has no buffer, and SetValue(const T*) asserts on NULL.
        if (paramMatrix.m_Data != NULL)
        {
            SetValue(paramMatrix.m_Data);
        }
    }

    ~XMatrix()
    {
        FreeMatrix();
    }

    /// Deep-copy assignment (copy-and-swap): the copy is built before *this
    /// is touched, so a failed allocation leaves *this unchanged.
    XMatrix & operator = (const XMatrix<T> & paramMatrix)
    {
        if (this != &paramMatrix)
        {
            XMatrix<T> copy(paramMatrix);
            Swap(copy);
        }
        return *this;
    }

    /// Element access, 1-based: (1,1) is the top-left element.
    T & operator()(int paramRow, int paramCol)
    {
        return GetValue(paramRow, paramCol);
    }

    const T & operator()(int paramRow, int paramCol) const
    {
        return GetValue(paramRow, paramCol);
    }

    /// Copies GetRow() * GetCol() values from paramValueList in row-major
    /// order. paramValueList must be non-NULL and hold at least that many values.
    void SetValue(const T * paramValueList)
    {
        assert(paramValueList != NULL);
        for (int i = 1; i <= m_Row; i++)
        {
            for (int j = 1; j <= m_Col; j++)
            {
                SetValue(i, j, *paramValueList);
                paramValueList++;
            }
        }
    }

    /// Sets the element at (paramRow, paramCol), 1-based.
    void SetValue(int paramRow, int paramCol, const T & paramValue)
    {
        GetValue(paramRow, paramCol) = paramValue;
    }

    /// Number of rows.
    int GetRow() const
    {
        return m_Row;
    }

    /// Number of columns.
    int GetCol() const
    {
        return m_Col;
    }

    /// Returns a copy with every element multiplied (on the right) by paramValue.
    XMatrix<T> operator * (const T & paramValue) const
    {
        XMatrix<T> r(*this);
        for (int i = 1; i <= r.GetRow(); i++)
        {
            for (int j = 1; j <= r.GetCol(); j++)
            {
                r(i, j) = r(i, j) * paramValue;
            }
        }
        return r;
    }

    /// Element-wise sum; both matrices must have the same dimensions.
    XMatrix<T> operator + (const XMatrix<T> & paramMatrix) const
    {
        assert(m_Col == paramMatrix.m_Col && m_Row == paramMatrix.m_Row);
        XMatrix<T> r(*this);
        for (int i = 1; i <= r.GetRow(); i++)
        {
            for (int j = 1; j <= r.GetCol(); j++)
            {
                r(i, j) = r(i, j) + paramMatrix(i, j);
            }
        }
        return r;
    }

    /// Element-wise difference; both matrices must have the same dimensions.
    XMatrix<T> operator - (const XMatrix<T> & paramMatrix) const
    {
        assert(m_Col == paramMatrix.m_Col && m_Row == paramMatrix.m_Row);
        XMatrix<T> r(*this);
        for (int i = 1; i <= r.GetRow(); i++)
        {
            for (int j = 1; j <= r.GetCol(); j++)
            {
                r(i, j) = r(i, j) - paramMatrix(i, j);
            }
        }
        return r;
    }

    /// Matrix product (*this) * paramMatrix. Requires
    /// GetCol() == paramMatrix.GetRow(); the result is
    /// GetRow() x paramMatrix.GetCol().
    XMatrix<T> operator * (const XMatrix<T> & paramMatrix) const
    {
        assert(m_Col == paramMatrix.m_Row);
        XMatrix<T> r;
        // InitMatrix (unlike the public constructor) accepts zero sizes, so
        // multiplying empty matrices does not trip an assertion.
        r.InitMatrix(m_Row, paramMatrix.m_Col);
        for (int i = 1; i <= r.GetRow(); i++)
        {
            for (int j = 1; j <= r.GetCol(); j++)
            {
                T sum = T();
                for (int k = 1; k <= m_Col; k++)
                {
                    sum = sum + GetValue(i, k) * paramMatrix(k, j);
                }
                r(i, j) = sum;
            }
        }
        return r;
    }

    /// Sets every element to paramValue.
    void Clear(const T & paramValue)
    {
        for (int i = 1; i <= GetRow(); i++)
        {
            for (int j = 1; j <= GetCol(); j++)
            {
                SetValue(i, j, paramValue);
            }
        }
    }

    /// Bounds-checked (via assert) 1-based element access.
    const T & GetValue(int paramRow, int paramCol) const
    {
        assert(paramCol > 0 && paramRow > 0);
        assert(paramRow <= GetRow() && paramCol <= GetCol());
        return m_Data[(paramRow - 1) * m_Col + (paramCol - 1)];
    }

    T & GetValue(int paramRow, int paramCol)
    {
        assert(paramCol > 0 && paramRow > 0);
        assert(paramRow <= GetRow() && paramCol <= GetCol());
        return m_Data[(paramRow - 1) * m_Col + (paramCol - 1)];
    }

    /// True when both matrices have the same dimensions and equal elements.
    /// Matrices of different sizes simply compare unequal.
    bool operator == (const XMatrix<T> & paramMatrix) const
    {
        if (m_Row != paramMatrix.m_Row || m_Col != paramMatrix.m_Col)
        {
            return false;
        }
        for (int i = 1; i <= GetRow(); i++)
        {
            for (int j = 1; j <= GetCol(); j++)
            {
                if (GetValue(i, j) != paramMatrix(i, j))
                {
                    return false;
                }
            }
        }
        // Every element matched.
        return true;
    }

    bool operator != (const XMatrix<T> & paramMatrix) const
    {
        return !(*this == paramMatrix);
    }

    /// Returns the transpose: a GetCol() x GetRow() matrix with r(j,i) == (*this)(i,j).
    XMatrix<T> Transpose() const
    {
        XMatrix<T> r;
        // Goes through InitMatrix so transposing an empty matrix is allowed.
        r.InitMatrix(GetCol(), GetRow());
        for (int i = 1; i <= GetRow(); i++)
        {
            for (int j = 1; j <= GetCol(); j++)
            {
                r(j, i) = GetValue(i, j);
            }
        }
        return r;
    }

    /// Discards the current contents and becomes a parmaRow x paramCol
    /// value-initialized matrix. Both sizes must be positive.
    void Reset(int parmaRow, int paramCol)
    {
        assert(paramCol > 0 && parmaRow > 0);
        InitMatrix(parmaRow, paramCol);
    }

private:
    /// Replaces the storage with a parmaRow x paramCol value-initialized
    /// buffer. Non-positive sizes leave the matrix empty (0 x 0, no buffer).
    void InitMatrix(int parmaRow, int paramCol)
    {
        FreeMatrix();
        if (parmaRow <= 0 || paramCol <= 0)
        {
            return;
        }
        // "()" value-initializes every element (0 for arithmetic T). Unlike
        // zeroing raw bytes, this is also valid for class-type elements.
        m_Data = new T[parmaRow * paramCol]();
        // Sizes are committed only after the allocation succeeded.
        m_Row = parmaRow;
        m_Col = paramCol;
    }

    /// Releases the buffer and resets the matrix to 0 x 0.
    void FreeMatrix()
    {
        delete[] m_Data;  // delete[] on NULL is a no-op
        m_Data = NULL;
        m_Col = 0;
        m_Row = 0;
    }

    /// Exchanges storage and dimensions with paramOther (no allocation, cannot throw).
    void Swap(XMatrix<T> & paramOther)
    {
        T * data = m_Data;
        m_Data = paramOther.m_Data;
        paramOther.m_Data = data;

        int col = m_Col;
        m_Col = paramOther.m_Col;
        paramOther.m_Col = col;

        int row = m_Row;
        m_Row = paramOther.m_Row;
        paramOther.m_Row = row;
    }

private:
    T * m_Data;  // row-major buffer of m_Row * m_Col elements, or NULL when empty
    int m_Col;
    int m_Row;
};
240
241
242 template<class T>
243 XMatrix<T> operator * (const T & paramValue, const XMatrix<T> & paramMatrix)
244 {
245 return paramMatrix * paramValue;
246 }
247
248 template<class T>
249 ostream & operator << (ostream & o,const XMatrix<T> & paramMatrix)
250 {
251 for(int i = 1; i <= paramMatrix.GetRow(); i++)
252 {
253 for(int j = 1; j <= paramMatrix.GetCol(); j++)
254 {
255 o << paramMatrix(i,j);
256 if( j < paramMatrix.GetCol() ) cout<<",";
257 }
258 cout<<endl;
259 }
260 return o;
261 }
262
263 #endif
264
浮点数矩阵定义
// FloatMatrix: the single-precision matrix type taken and returned by the
// D3DX-style helper functions (XMatrixIndentity, XMatrixScaling, ...).
typedef XMatrix<float> FloatMatrix;
为了验证,特意写了一组仿照DX中D3DX矩阵生成函数的函数。只需把函数名的前缀X换成D3DX,就得到DX中对应的函数(唯一的例外是XMatrixIndentity,它对应的是D3DXMatrixIdentity)。
1 FloatMatrix * XMatrixIndentity( FloatMatrix * paramMatrixOut);
2 FloatMatrix * XMatrixScaling(FloatMatrix * paramMatrixOut,float paramX, float paramY, float paramZ);
3 FloatMatrix * XMatrixTranslation(FloatMatrix * paramMatrixOut,float paramX, float paramY, float paramZ);
4 FloatMatrix * XMatrixMultiply(FloatMatrix * paramMatrixOut, FloatMatrix * paramM1, FloatMatrix * paramM2);
5 FloatMatrix * XMatrixRotationX(FloatMatrix * paramMatrixOut, float paramAngle);
6 FloatMatrix * XMatrixRotationY(FloatMatrix * paramMatrixOut, float paramAngle);
7 FloatMatrix * XMatrixRotationZ(FloatMatrix * paramMatrixOut, float paramAngle);
8 FloatMatrix * XMatrixOrthoOffCenterLH(FloatMatrix * paramMatrixOut, float paramLeft, float paramRight, float paramBottom, float paramTop, float paramZ_Near, float paramZ_Far);
9
10 FloatMatrix * XMatrixScaling(FloatMatrix * paramMatrixOut,float paramX, float paramY, float paramZ)
11 {
12 assert(paramMatrixOut != NULL&& paramMatrixOut->GetCol() == 4 && paramMatrixOut->GetRow() == 4);
13 XMatrixIndentity(paramMatrixOut);
14 paramMatrixOut->SetValue(1,1,paramX);
15 paramMatrixOut->SetValue(2,2,paramY);
16 paramMatrixOut->SetValue(3,3,paramZ);
17 paramMatrixOut->SetValue(4,4,1);
18 return paramMatrixOut;
19 }
20
21 FloatMatrix * XMatrixIndentity( FloatMatrix * paramMatrixOut)
22 {
23 assert(paramMatrixOut != NULL&& paramMatrixOut->GetCol() == 4 && paramMatrixOut->GetRow() == 4);
24 paramMatrixOut->Clear(0.0f);
25 paramMatrixOut->SetValue(1,1,1.0f);
26 paramMatrixOut->SetValue(2,2,1.0f);
27 paramMatrixOut->SetValue(3,3,1.0f);
28 paramMatrixOut->SetValue(4,4,1.0f);
29 return paramMatrixOut;
30 }
31
32 FloatMatrix * XMatrixTranslation(FloatMatrix * paramMatrixOut,float paramX, float paramY, float paramZ)
33 {
34 assert(paramMatrixOut != NULL&& paramMatrixOut->GetCol() == 4 && paramMatrixOut->GetRow() == 4);
35 XMatrixIndentity(paramMatrixOut);
36 paramMatrixOut->SetValue(4,1,paramX);
37 paramMatrixOut->SetValue(4,2,paramY);
38 paramMatrixOut->SetValue(4,3,paramZ);
39 return paramMatrixOut;
40 }
41
42 FloatMatrix * XMatrixMultiply(FloatMatrix * paramMatrixOut, FloatMatrix * paramM1, FloatMatrix * paramM2)
43 {
44 assert(paramMatrixOut != NULL&& paramMatrixOut->GetCol() == 4 && paramMatrixOut->GetRow() == 4);
45 assert(paramM1 != NULL&& paramM1->GetCol() == 4 && paramM1->GetRow() == 4);
46 assert(paramM2 != NULL&& paramM2->GetCol() == 4 && paramM2->GetRow() == 4);
47 *paramMatrixOut = (*paramM1) * (*paramM2);
48 return paramMatrixOut;
49 }
50
51 FloatMatrix * XMatrixRotationX(FloatMatrix * paramMatrixOut, float paramAngle)
52 {
53 assert(paramMatrixOut != NULL&& paramMatrixOut->GetCol() == 4 && paramMatrixOut->GetRow() == 4);
54 XMatrixIndentity(paramMatrixOut);
55 paramMatrixOut->GetValue(2,2) = cosf(paramAngle);
56 paramMatrixOut->GetValue(2,3) = sinf(paramAngle);
57 paramMatrixOut->GetValue(3,2) = -sinf(paramAngle);
58 paramMatrixOut->GetValue(3,3) = cosf(paramAngle);
59 return paramMatrixOut;
60 }
61
62 FloatMatrix * XMatrixRotationY(FloatMatrix * paramMatrixOut, float paramAngle)
63 {
64 assert(paramMatrixOut != NULL&& paramMatrixOut->GetCol() == 4 && paramMatrixOut->GetRow() == 4);
65 XMatrixIndentity(paramMatrixOut);
66 paramMatrixOut->GetValue(1,1) = cosf(paramAngle);
67 paramMatrixOut->GetValue(1,3) = -sinf(paramAngle);
68 paramMatrixOut->GetValue(3,1) = sinf(paramAngle);
69 paramMatrixOut->GetValue(3,3) = cosf(paramAngle);
70 return paramMatrixOut;
71 }
72
73
74 FloatMatrix * XMatrixRotationZ(FloatMatrix * paramMatrixOut, float paramAngle)
75 {
76 assert(paramMatrixOut != NULL&& paramMatrixOut->GetCol() == 4 && paramMatrixOut->GetRow() == 4);
77 XMatrixIndentity(paramMatrixOut);
78 paramMatrixOut->GetValue(1,1) = cosf(paramAngle);
79 paramMatrixOut->GetValue(1,2) = sinf(paramAngle);
80 paramMatrixOut->GetValue(2,1) = -sinf(paramAngle);
81 paramMatrixOut->GetValue(2,2) = cosf(paramAngle);
82 return paramMatrixOut;
83 }
84 // 2/(r-l) 0 0 0
85 // 0 2/(t-b) 0 0
86 // 0 0 1/(zf-zn) 0
87 // (l+r)/(l-r) (t+b)/(b-t) zn/(zn-zf) 1
88 FloatMatrix * XMatrixOrthoOffCenterLH(FloatMatrix * paramMatrixOut, float paramLeft, float paramRight, float paramBottom, float paramTop, float paramZ_Near, float paramZ_Far)
89 {
90 assert(paramMatrixOut != NULL&& paramMatrixOut->GetCol() == 4 && paramMatrixOut->GetRow() == 4);
91 XMatrixIndentity(paramMatrixOut);
92 paramMatrixOut->GetValue(1,1) = 2.0f /( paramRight - paramLeft);
93 paramMatrixOut->GetValue(2,2) = 2.0f /( paramTop - paramBottom);
94 paramMatrixOut->GetValue(3,3) = 1.0f /( paramZ_Far - paramZ_Near);
95 paramMatrixOut->GetValue(4,1) = (paramLeft + paramRight) / (paramLeft - paramRight);
96 paramMatrixOut->GetValue(4,2) = (paramTop + paramBottom) / (paramBottom - paramTop);
97 paramMatrixOut->GetValue(4,3) = paramZ_Near/(paramZ_Near-paramZ_Far);
98 return paramMatrixOut;
99 }