题目的意思是:
给一个有n个点,m条边的无向图
两点之间可以存在多条边
现在每次随机增加一条边
问使得全部点都连通需要增加多少次(期望值)
首先,求出所有连通分量。用并查集。
每次随机增加一条边的时候一共有两种情况:
1)这条边连接了两个不同的连通分量,它的概率是p
2)这条边在一个连通分量里,它的概率是q = 1 - p
前者可以改变连通分量的数量,后者不能
如果把当前图的状态视为一个子问题
那么就可以用动态规划解决问题了
图的状态可以表示为:有多少个连通分量,每个连通分量包含多少个点
比如说图的状态 (2, 3, 3) 表示有三个连通分量,每个连通分量包含的点的个数分别为 2, 3, 3
动态规划的转移方程为:
f = p*(1+r) + p*q*(2+r) + p*q^2*(3+r) ....
其中r为p发生后,新状态的期望值
这个东西高中的时候学过,呵呵。
而1)中也包含多种情况,需要两两枚举
最大的问题是,f的值是一个无限数列,它的极值很难求。但无论如何,有高手求出来了。。在这里:http://archive.cnblogs.com/a/1325929/
它的极值是 f = p * (1 / (1 - q) + r) / (1 - q)
我对照了一下标程,确实是这个。
后来我自己推导了一下,发现它可以化成多个等比数列相加的形式,求和后,发现当n趋向于无穷大的时候,它的极限就是上面这个公式。
(注意:i*q^i, 当0<q<1,i趋向于无穷大的时候等于0)
这样程序就可以写了。动态规划保存每个图的状态。
如果用python写,只要建立一个tuple到float的映射就可以了。非常方便。
java中也有List<int>到Double的映射。
c里面估计就得用hash了。
py代码,参照标程写的。
fi = open('in')
fo = open('out')
dp = {():0}
ti = 0
def get(s):
if s in dp:
return dp[s]
q = sum([i*(i-1) for i in s])*1.0/2/nn
res = 0
for i in range(len(s)):
for j in range(len(s)):
if i < j:
l = list(s)
del l[max(i,j)]
del l[min(i,j)]
l.append(s[i]+s[j])
l.sort()
r = get(tuple(l))
p = s[i]*s[j]*1.0/nn
res += p*(1+r-r*q)/pow(1-q,2)
dp[s] = res
return res
while 1:
a = fi.readline().split()
if a == None or len(a) != 2:
break
N, M = int(a[0]), int(a[1])
nn = N*(N-1)/2
s = [ i for i in range(N) ]
for i in range(M):
u, v = [ int(i) for i in fi.readline().split() ]
u -= 1
v -= 1
k = s[u]
for j in range(N):
if s[j] == k:
s[j] = s[v]
ss = [ s.count(i) for i in set(s) ]
ss.sort()
print '----', ti
mine = get(tuple(ss))
ans = float(fo.readline().strip())
print 'mine', mine, 'ans', ans
print len(dp)
ti += 1
标程
用很简洁的代码写了并查集,值得借鉴!
import java.util.*;
import java.io.File;
import java.io.PrintWriter;
import java.io.FileNotFoundException;
public class interconnect_pm {
private static int nn;
public static void main(String[] args) throws FileNotFoundException {
Scanner in = new Scanner(new File("in"));
PrintWriter out = new PrintWriter("ans.out");
int n = in.nextInt();
nn = (n * (n - 1)) / 2;
int m = in.nextInt();
int[] p = new int[n];
for (int i = 0; i < n; i++) p[i] = i;
for (int i = 0; i < m; i++) {
int u = in.nextInt();
int v = in.nextInt();
u--;
v--;
int k = p[u];
for (int j = 0; j < n; j++) {
if (p[j] == k) {
p[j] = p[v];
}
}
}
List<Integer> st = new ArrayList<Integer>();
for (int i = 0; i < n; i++) {
int s = 0;
for (int j = 0; j < n; j++) {
if (p[j] == i) s++;
}
if (s > 0) {
st.add(s);
}
}
Collections.sort(st);
List<Integer> fn = new ArrayList<Integer>();
fn.add(n);
mem.put(fn, 0.0);
out.println(get(st));
System.out.println(mem.size());
out.close();
}
static Map<List<Integer>, Double> mem = new HashMap<List<Integer>, Double>();
private static double get(List<Integer> st) {
Double ret = mem.get(st);
if (ret != null) return ret;
int m = st.size();
int[][] a = new int[m][m];
for (int i = 0; i < m; i++) {
for (int j = i + 1; j < m; j++) {
a[i][j] = st.get(i) * st.get(j);
}
}
int s = 0;
for (int i = 0; i < m; i++) {
s += st.get(i) * (st.get(i) - 1) / 2;
}
double res = 0;
for (int i = 0; i < m; i++) {
for (int j = i + 1; j < m; j++) {
List<Integer> ss = new ArrayList<Integer>(st.size() - 1);
boolean q = true;
int z = st.get(i) + st.get(j);
for (int k = 0; k < m; k++) {
if (k != i && k != j) {
int zz = st.get(k);
if (q && zz >= z) {
q = false;
ss.add(z);
}
ss.add(zz);
}
}
if (q)
ss.add(z);
double p = a[i][j] * 1.0 / (nn - s);
double e = a[i][j] * 1.0 / ((1 - s * 1.0 / nn) * (1 - s * 1.0 / nn) * nn);
e = e + get(ss) * p;
res += e;
}
}
System.out.println(st);
mem.put(st, res);
return res;
}
}