posts - 183,  comments - 10,  trackbacks - 0

K-近邻法的实现

K-近邻法是根据距离最近的 k 个样例的类型来推测该样例的类型。
实现中中主要的环节有:
·训练样例格式和测试样例格式的定义
·样例结构体的定义
·训练样例和测试样例的读取
·样例距离的计算,欧氏距离
·距离矩阵的生成
·针对每个测试样例,得到其到每个训练样例的距离,根据距离由小到大排序,更具距离和权重成反比的关系,计算每个类型的总的权重,得到最大权重的那个类型,即是当前测试样例的类型
·k 值的选择和设定

训练样本的格式是:
每行代表一个样例,每行的第一个元素是该样例的类型,后面是该样例的特征向量
例如:

train.txt
a    1 2 3 4 5
b    5 4 3 2 1
c    3 3 3 3 3
d    -3 -3 -3 -3 -3
a    1 2 3 4 4
b    4 4 3 2 1
c    3 3 3 2 4
d    0 0 1 1 -2

 


测试样例的格式是:
每行代表一个样例,每行即是该样例的特征向量
例如:
test.txt
1 2 3 2 4
2 3 4 2 1
8 7 2 3 5
-3 -2 2 4 0
-4 -4 -4 -4 -4
1 2 3 4 4
4 4 3 2 1
3 3 3 2 4
0 0 1 1 -2

测试样例的输出结果的格式是:
格式与训练样例一样,即每行代表一个样例,每行的第一个元素是学习到的测试样例的类型,后面是该样例的特征向量
例如:
result.txt
a    1 2 3 2 4 
b    2 3 4 2 1 
b    8 7 2 3 5 
a    -3 -2 2 4 0 
d    -4 -4 -4 -4 -4 
a    1 2 3 4 4 
b    4 4 3 2 1 
c    3 3 3 2 4 
d    0 0 1 1 -2 

具体的程序实现如下:

  1 /*
  2     K-近邻法的实现
  3     mark
  4     goonyangxiaofang@163.com
  5     QQ 591 247 876
  6     2012.02.13
  7 */
  8 
  9 
 10 #include <iostream>
 11 #include <string>
 12 #include <vector>
 13 #include <set>
 14 #include <map>
 15 #include <fstream>
 16 #include <sstream>
 17 #include <cassert>
 18 #include <cmath>
 19 using namespace std;
 20 
 21 // 样例结构体,所属类型和特征向量
 22 struct sample
 23 {
 24     string type;
 25     vector<double> features;
 26 };
 27 
 28 // 类型和距离结构体,未用到
 29 struct typeDistance
 30 {
 31     string type;
 32     double distance;
 33 };
 34 
 35 bool operator < (const typeDistance& lhs, const typeDistance& rhs)
 36 {
 37     return lhs.distance < rhs.distance;
 38 }
 39 
 40 // 读取训练样本
 41 // 训练样本的格式是:每行代表一个样例
 42 // 每行的第一个元素是类型名,后面的是样例的特征向量
 43 // 例如:
 44 /*
 45 a    1 2 3 4 5
 46 b    5 4 3 2 1
 47 c    3 3 3 3 3
 48 d    -3 -3 -3 -3 -3
 49 a    1 2 3 4 4
 50 b    4 4 3 2 1
 51 c    3 3 3 2 4
 52 d    0 0 1 1 -2
 53 */
 54 void readTrain(vector<sample>& train, const string& file)
 55 {
 56     ifstream fin(file.c_str());
 57     if (!fin)
 58     {
 59         cerr << "File error!" << endl;
 60         exit(1);
 61     }
 62     string line;
 63     double d = 0.0;
 64     while (getline(fin, line))
 65     {
 66         istringstream sin(line);
 67         sample ts;
 68         sin >> ts.type;
 69         while (sin >> d)
 70         {
 71             ts.features.push_back(d);
 72         }
 73         train.push_back(ts);
 74     }
 75     fin.close();
 76 }
 77 
 78 // 读取测试样本
 79 // 每行代表一个样例
 80 // 每一行是一个样例的特征向量
 81 // 例如:
 82 /*
 83 1 2 3 2 4
 84 2 3 4 2 1
 85 8 7 2 3 5
 86 -3 -2 2 4 0
 87 -4 -4 -4 -4 -4
 88 1 2 3 4 4
 89 4 4 3 2 1
 90 3 3 3 2 4
 91 0 0 1 1 -2
 92 */
 93 void readTest(vector<sample>& test, const string& file)
 94 {
 95     ifstream fin(file.c_str());
 96     if (!fin)
 97     {
 98         cerr << "File error!" << endl;
 99         exit(1);
100     }
101     double d = 0.0;
102     string line;
103     while (getline(fin, line))
104     {
105         istringstream sin(line);
106         sample ts;
107         while (sin >> d)
108         {
109             ts.features.push_back(d);
110         }
111         test.push_back(ts);
112     }
113     fin.close();
114 }
115 
116 // 计算欧氏距离
117 double euclideanDistance(const vector<double>& v1, const vector<double>& v2)
118 {
119     assert(v1.size() == v2.size());
120     double ret = 0.0;
121     for (vector<double>::size_type i = 0; i != v1.size(); ++i)
122     {
123         ret += (v1[i] - v2[i]) * (v1[i] - v2[i]);
124     }
125     return sqrt(ret);
126 }
127 
128 // 初始化距离矩阵
129 // 该矩阵是根据训练样本和测试样本而得
130 // 矩阵的行数为测试样本的数目,列数为训练样本的数目
131 void initDistanceMatrix(vector<vector<double> >& dm, const vector<sample>& train, const vector<sample>& test)
132 {
133     for (vector<sample>::size_type i = 0; i != test.size(); ++i)
134     {
135         vector<double> vd;
136         for (vector<sample>::size_type j = 0; j != train.size(); ++j)
137         {
138             vd.push_back(euclideanDistance(test[i].features, train[j].features));
139         }
140         dm.push_back(vd);
141     }
142 }
143 
144 // K-近邻法的实现
145 // 设定不同的 k 值,给每个测试样例予以一个类型
146 // 距离和权重成反比
147 void knnProcess(vector<sample>& test, const vector<sample>& train, const vector<vector<double> >& dm, unsigned int k)
148 {
149     for (vector<sample>::size_type i = 0; i != test.size(); ++i)
150     {
151         multimap<doublestring> dts;
152         for (vector<double>::size_type j = 0; j != dm[i].size(); ++j)
153         {
154             if (dts.size() < k)
155             {
156                 dts.insert(make_pair(dm[i][j], train[j].type));
157             }
158             else
159             {
160                 multimap<doublestring>::iterator it = dts.end();
161                 --it;
162                 if (dm[i][j] < it->first)
163                 {
164                     dts.erase(it);
165                     dts.insert(make_pair(dm[i][j], train[j].type));
166                 }
167             }
168         }
169         map<stringdouble> tds;
170         string type = "";
171         double weight = 0.0;
172         for (multimap<doublestring>::const_iterator cit = dts.begin(); cit != dts.end(); ++cit)
173         {
174             // 不考虑权重的情况,在 k 个样例中只要出现就加 1
175             // ++tds[cit->second];
176 
177             // 这里是考虑距离与权重的关系,距离越大权重越小
178             tds[cit->second] += 1.0 / cit->first;
179             if (tds[cit->second] > weight)
180             {
181                 weight = tds[cit->second];
182                 type = cit->second;
183             }
184         }
185         test[i].type = type;
186     }
187 }
188 
189 // 输出结果
190 // 输出的格式和训练样本的格式一样
191 // 每行表示一个样例,第一个元素是该样例的类型,后面是该样例的特征向量
192 // 例如:
193 /*
194 a    1 2 3 2 4 
195 b    2 3 4 2 1 
196 b    8 7 2 3 5 
197 a    -3 -2 2 4 0 
198 d    -4 -4 -4 -4 -4 
199 a    1 2 3 4 4 
200 b    4 4 3 2 1 
201 c    3 3 3 2 4 
202 d    0 0 1 1 -2 
203 */
204 void writeTest(const vector<sample>& test, const string& file)
205 {
206     ofstream fout(file.c_str());
207     if (!fout)
208     {
209         cerr << "File error!" << endl;
210         exit(1);
211     }
212     for (vector<sample>::size_type i = 0; i != test.size(); ++i)
213     {
214         fout << test[i].type << '\t';
215         for (vector<double>::size_type j = 0; j != test[i].features.size(); ++j)
216         {
217             fout << test[i].features[j] << ' ';
218         }
219         fout << endl;
220     }
221 }
222 
223 // 封装
224 void knn(const string& file1, const string& file2, const string& file3, int k)
225 {
226     vector<sample> train, test;
227     readTrain(train, file1.c_str());
228     readTest(test, file2.c_str());
229     vector<vector<double> > dm;
230     initDistanceMatrix(dm, train, test);
231     knnProcess(test, train, dm, k);
232     writeTest(test, file3.c_str());
233 }
234 
235 // 测试
236 int main()
237 {
238     knn("train.txt""test.txt""result.txt"5);
239     return 0;
240 }
241 




posted on 2012-02-14 09:47 unixfy 阅读(4907) 评论(0)  编辑 收藏 引用

只有注册用户登录后才能发表评论。
网站导航: 博客园   IT新闻   BlogJava   博问   Chat2DB   管理