/*
Name: 三元组稀疏矩阵类
始发于goal00001111的专栏;允许自由转载,但必须注明作者和出处
Author: goal00001111
Date: 07-05-10 09:51
Description: 三元组稀疏矩阵类:实现了矩阵的转置,加法和乘法运算
*/
#include <cstdlib>
#include <iostream>
using namespace std;
const int MAXTERMS = 10000;
template <typename T> class SparseMatrix;//前置声明
template <typename T>
class Trituple //三元组类
{
friend class SparseMatrix<T>;
private:
Trituple<T> & operator = (const Trituple<T> & b);
int row, col;
T value;
};
template <typename T> //重载赋值运算符
Trituple<T> & Trituple<T>::operator = (const Trituple<T> & b)
{
if (this != &b)
{
row = b.row;
col = b.col;
value = b.value;
}
return *this;
}
template <typename T>
class SparseMatrix //三元组稀疏矩阵类
{
public:
SparseMatrix(int maxRow, int maxCol, int maxTerms = 0); //构造函数1
SparseMatrix(const T a[], int maxRow, int maxCol); //构造函数2
SparseMatrix(const SparseMatrix<T> & b); //拷贝构造函数
SparseMatrix<T> & operator = (const SparseMatrix<T> & b);//重载赋值运算符
void Display(); //输出三元组
void DisArray();//输出矩阵
SparseMatrix<T> Transpose();//三元组顺序表:转置矩阵
SparseMatrix<T> FastTranspose();//三元组顺序表:快速转置矩阵
SparseMatrix<T> Add(SparseMatrix<T> b);//矩阵相加
SparseMatrix<T> Multiply(SparseMatrix<T> b);//矩阵相乘
private:
int rows, cols, terms;//矩阵的行数,列数和非零元素个数
Trituple<T> smArray[MAXTERMS];//三元组顺序表
};
template <typename T> //构造函数1
SparseMatrix<T>::SparseMatrix(int maxRow, int maxCol, int maxTerms):rows(maxRow), cols(maxCol), terms(maxTerms)
{
}
template <typename T> //构造函数2:将二维数组转换为三元组
SparseMatrix<T>::SparseMatrix(const T a[], int maxRow, int maxCol):rows(maxRow), cols(maxCol), terms(0)
{
for (int i=0; i<rows; i++)
{
for (int j=0; j<cols; j++)
{
if (a[i*cols+j] != 0)
{
smArray[terms].row = i;
smArray[terms].col = j;
smArray[terms++].value = a[i*cols+j];
}
if (terms > MAXTERMS)
return;
}
}
}
template <typename T> //拷贝构造函数
SparseMatrix<T>::SparseMatrix(const SparseMatrix<T> & b)
{
rows = b.rows;
cols = b.cols;
terms = b.terms;
for (int i=0; i<terms; i++)
smArray[i] = b.smArray[i];
}
template <typename T>//重载赋值运算符
SparseMatrix<T> & SparseMatrix<T>::operator = (const SparseMatrix<T> & b)
{
if (this != &b)
{
rows = b.rows;
cols = b.cols;
terms = b.terms;
for (int i=0; i<terms; i++)
smArray[i] = b.smArray[i];
}
return *this;
}
template <typename T>//输出三元组
void SparseMatrix<T>::Display()
{
cout << "rows = " << rows << ", cols = " << cols << ", terms = " << terms << endl;
for (int i=0; i<terms; i++)
cout << i+1 << "(" << smArray[i].row << ", " << smArray[i].col << ", " << smArray[i].value << ")\t";
cout << endl;
}
template <typename T>//输出矩阵:包括非零元素
void SparseMatrix<T>::DisArray()
{
int top = 0;
for (int i=0; i<rows; i++)
{
for (int j=0; j<cols; j++)
{
if (i == smArray[top].row && j == smArray[top].col)
cout << smArray[top++].value << " ";
else
cout << "0 ";
}
cout << endl;
}
cout << endl;
}
template <typename T>
SparseMatrix<T> SparseMatrix<T>::Transpose()//三元组顺序表:转置矩阵
{
SparseMatrix<T> t(cols, rows, terms);
if(terms > 0)
{
int top = 0;
for(int j=0; j<cols; j++) //按照T的行序(M的列序)为主序依次排列
for(int i=0; i<terms; i++)//扫描M的所有元素
if(smArray[i].col == j)
{
t.smArray[top].row = smArray[i].col;
t.smArray[top].col = smArray[i].row;
t.smArray[top++].value = smArray[i].value;
}//if
} //else
return t;
}
template <typename T>
SparseMatrix<T> SparseMatrix<T>::FastTranspose()//三元组顺序表:快速转置矩阵
{
SparseMatrix<T> t(cols, rows, terms);
int * colSize = new int[cols];//存储每列的非零元素个数
int * colStart = new int[cols];//存储每列第一个非零元素在三元组中的位置(下标)
if(terms > 0)
{
for(int i=0; i<cols; i++)
colSize[i] = 0;
for(int i=0; i<terms; i++)//扫描M的所有元素
colSize[smArray[i].col]++; //确定矩阵M每一列中非零元素的个数
colStart[0] = 0;// 确定矩阵M第i列中第一个非零元素在t.smArray中的位置
for(int i=1; i<cols; i++)
colStart[i] = colStart[i-1] + colSize[i-1];
for(int i=0; i<terms; i++)//扫描M的所有元素
{
int k = colStart[smArray[i].col]; //k即矩阵M第col列中第一个非零元素在t.smArray中的位置
t.smArray[k].row = smArray[i].col;
t.smArray[k].col = smArray[i].row;
t.smArray[k].value = smArray[i].value;
colStart[smArray[i].col]++; //矩阵M第col列中第一个非零元素在t.smArray中的位置向前移动一位
}//for
}//if
delete []colSize;
delete []colStart;
return t;
}
template <typename T>
SparseMatrix<T> SparseMatrix<T>::Add(SparseMatrix<T> b)//矩阵相加:采用合并算法,行序优先
{
SparseMatrix<int> c(cols, rows);
int i = 0, j = 0, k = 0;
while (i < terms && j < b.terms)
{
if (smArray[i].row == b.smArray[j].row && smArray[i].col == b.smArray[j].col)
{
c.smArray[k].row = smArray[i].col;
c.smArray[k].col = smArray[i].row;
c.smArray[k].value = smArray[i++].value + b.smArray[j++].value;
if (c.smArray[k].value != 0)
k++;
}
else if ((smArray[i].row < b.smArray[j].row) ||(smArray[i].row == b.smArray[j].row && smArray[i].col < b.smArray[j].col))
c.smArray[k++] = smArray[i++];
else
c.smArray[k++] = b.smArray[j++];
}//while
if (i > terms) //A结束,若B还有元素,则将B的元素直接放入C中
{
while (j < b.terms)
c.smArray[k++] = b.smArray[j++];
}
else //B结束,若A还有元素,则将A的元素直接放入C中
{
while (i < terms)
c.smArray[k++] = smArray[i++];
}
c.terms = k;
return c;
}
template <typename T> //矩阵相乘:快速乘法
SparseMatrix<T> SparseMatrix<T>::Multiply(SparseMatrix<T> b)
{
SparseMatrix<T> t(0, 0 , 0);
if(b.rows != cols)
return t;
SparseMatrix<T> c(rows, b.cols);
int * rowSize = new int[b.rows]; //存储b每行的非零元素个数
int * rowStart = new int[b.rows];//存储b每行的首个非零元素位置
int * ctemp = new int[b.cols]; //存储b中与a某个元素做乘法运算的第col列元素的累积值
if(terms * b.terms != 0)//若C是非零矩阵
{
for(int i=0; i<b.rows; i++)
rowSize[i] = 0;
for(int i=0; i<b.terms; i++)//扫描b的所有元素
rowSize[b.smArray[i].row]++; //确定矩阵b每一行中非零元素的个数
rowStart[0] = 0;// 确定矩阵b第i行中第一个非零元素在b.smArray中的位置
for(int i=1; i<b.rows; i++)
rowStart[i] = rowStart[i-1] + rowSize[i-1];
int k = 0, top = 0;
for(int i=0; i<rows; i++)//对A进行逐行处理,即对C进行逐行处理
{
for(int j=0; j<b.cols; j++)//当前各元素累加器清零
ctemp[j] = 0;
while (k < terms && smArray[k].row == i)//处理A的第i行元素
{
int tc = smArray[k].col;
for (int j=rowStart[tc]; b.smArray[j].row == tc; j++)//处理B的第tc行数据:进行累积
{
ctemp[b.smArray[j].col] += smArray[k].value * b.smArray[j].value;
}
k++;
}
for(int j=0; j<b.cols; j++)//得到C的第i行中每一个元素的值
{
if(ctemp[j] != 0)//压缩存储该行非零元
{
if(top == MAXTERMS)
return t;
c.smArray[top].row = i;
c.smArray[top].col = j;
c.smArray[top++].value = ctemp[j];//C的元素值等于A的行数ctemp[j]的累积值
} //if(ctemp[j] != 0)
}
}//for arow
c.terms = top;
}
delete []rowSize;
delete []rowStart;
delete []ctemp;
return c;
} // MultSMatrix
int main(int argc, char *argv[])
{
SparseMatrix<int> obja(2, 3);
obja.Display();
int arr[100] = {0};
int r = 3, c = 3;
for (int i=0; i<r; i++)
for (int j=0; j<c; j++)
cin >> arr[i*c+j];
//
// for (int i=0; i<r; i++)
// {
// for (int j=0; j<c; j++)
// cout << arr[i*c+j] << " ";
// cout << endl;
// }
// cout << endl;
SparseMatrix<int> objb(arr, r, c);
objb.DisArray();
SparseMatrix<int> objc = objb.Transpose();
objc.DisArray();
//obja = objc;
// objc.Display();
// objc.DisArray();
//
// SparseMatrix<int> objd = obja.Transpose();
// objd.Display();
// objd.DisArray();
//
SparseMatrix<int> obje = objb.Multiply(objc);
obje.Display();
obje.DisArray();
system("PAUSE");
return EXIT_SUCCESS;
}