【题意】:给出n个点,m条边,边分为两种,一种是A公司的,一种是B公司的。边上有权值,问用n-1条边把n个点连起来的最小费用是多少,其中A公司的边刚好有k条。题目保证有解。
【题解】:很明显看到是求一颗最小生成树,不过有一个限制就是刚刚好有k条边是A公司的。想了很久不会做,看别人代码的。二分出一个最大值delta使得A公司的边加上这个值后再求MST时A公司的边有大于等于k条,然后答案就是cost of MST - k * delta。思想其实是加上一个delta值去逼近答案,最后可以求出这样的MST,如果最后求出的MST的A公司的边多于k条,一定存在与A公司等效且等价的B公司边,替换过来即可。
【代码】:
1 #include "iostream"
2 #include "cstdio"
3 #include "cstring"
4 #include "algorithm"
5 #include "vector"
6 #include "queue"
7 #include "cmath"
8 #include "string"
9 #include "cctype"
10 #include "map"
11 #include "iomanip"
12 #include "set"
13 #include "utility"
14 using namespace std;
15 typedef pair<int, int> pii;
16 #define pb push_back
17 #define mp make_pair
18 #define fi first
19 #define se second
20 #define sof(x) sizeof(x)
21 #define lc(x) (x << 1)
22 #define rc(x) (x << 1 | 1)
23 #define lowbit(x) (x & (-x))
24 #define ll long long
25 struct Edge {
26 int u, v, w, id;
27 Edge(){}
28 Edge(int _u, int _v, int _w, int _id) {
29 u = _u, v = _v, w = _w, id = _id;
30 }
31 bool operator<(const Edge &x) const {
32 if(w != x.w) return w < x.w;
33 else return id < x.id;
34 }
35 }et[2][100050], e;
36 int tot, tot1;
37 int n, m, k, cost;
38 int fa[50050];
39
40 int find(int x) {
41 return (x == fa[x]) ? x : fa[x] = find(fa[x]);
42 }
43
44 bool merge(int u, int v) {
45 u = find(fa[u]), v = find(fa[v]);
46 if(u != v) {
47 fa[u] = v;
48 return true;
49 } else return false;
50 }
51
52 bool check(int w) {
53 cost = 0;
54 int cnt = 0;
55 for(int i = 0; i < n; i++) fa[i] = i;
56 int i = 0, j = 0;
57 while(i < tot || j < tot1) {
58 if(et[0][i].w + w <= et[1][j].w) {
59 e = et[0][i++];
60 e.w += w;
61 } else e = et[1][j++];
62 if(merge(e.u, e.v)) {
63 if(!e.id) cnt++;
64 cost += e.w;
65 }
66 }
67 return cnt >= k;
68 }
69
70 int main() {
71 int Case = 1;
72 while(~scanf("%d%d%d", &n, &m, &k)) {
73 tot = tot1 = 0;
74 for(int i = 0; i < m; i++) {
75 int u, v, w, id;
76 scanf("%d%d%d%d", &u, &v, &w, &id);
77 if(id) et[1][tot1++] = Edge(u, v, w, id);
78 else et[0][tot++] = Edge(u, v, w, id);
79 }
80 sort(et[0], et[0] + tot);
81 sort(et[1], et[1] + tot1);
82 et[0][tot].w = et[1][tot1].w = 1 << 30;
83 int l = -100, r = 100;
84 int w;
85 while(l <= r) {
86 int mid = (l + r) / 2;
87 if(check(mid)) w = mid, l = mid + 1;
88 else r = mid - 1;
89 }
90 check(w);
91 printf("Case %d: %d\n", Case++, cost - w * k);
92 }
93 return 0;
94 }
95