糯米

TI DaVinci, gstreamer, ffmpeg
随笔 - 167, 文章 - 0, 评论 - 47, 引用 - 0
数据加载中……

POJ 3156 Interconnect 图论+数论

题目的意思是:
给一个有n个点,m条边的无向图
两点之间可以存在多条边
现在每次随机增加一条边
问使得全部点都连通需要增加多少次(期望值)

首先,求出所有连通分量。用并查集。
每次随机增加一条边的时候一共有两种情况:
1)这条边连接了两个不同的连通分量,它的概率是p
2)这条边在一个连通分量里,它的概率是q = 1 - p
前者可以改变连通分量的数量,后者不能

如果把当前图的状态视为一个子问题
那么就可以用动态规划解决问题了
图的状态可以表示为:有多少个连通分量,每个连通分量包含多少个点
比如说图的状态 (2, 3, 3) 表示有三个连通分量,每个连通分量包含的点的个数分别为 2, 3, 3

动态规划的转移方程为:
f = p*(1+r) + p*q*(2+r) + p*q^2*(3+r) ....
其中r为p发生后,新状态的期望值
这个东西高中的时候学过,呵呵。

而1)中也包含多种情况,需要两两枚举
最大的问题是,f的值是一个无限数列,它的极值很难求。但无论如何,有高手求出来了。。在这里:http://archive.cnblogs.com/a/1325929/
它的极值是 f = p * (1 / (1 - q) + r) / (1 - q)
我对照了一下标程,确实是这个。
后来我自己推导了一下,发现它可以化成多个等比数列相加的形式,求和后,发现当n趋向于无穷大的时候,它的极限就是上面这个公式。
(注意:i*q^i,  当0<q<1,i趋向于无穷大的时候等于0)

这样程序就可以写了。动态规划保存每个图的状态。
如果用python写,只要建立一个tuple到float的映射就可以了。非常方便。
java中也有List<int>到Double的映射。
c里面估计就得用hash了。

py代码,参照标程写的。


fi 
= open('in')
fo 
= open('out')
dp 
= {():0}
ti 
= 0

def get(s):
    
if s in dp:
        
return dp[s]
    q 
= sum([i*(i-1for i in s])*1.0/2/nn
    res 
= 0
    
for i in range(len(s)):
        
for j in range(len(s)):
            
if i < j:
                l 
= list(s)
                
del l[max(i,j)]
                
del l[min(i,j)]
                l.append(s[i]
+s[j])
                l.sort()
                r 
= get(tuple(l))
                p 
= s[i]*s[j]*1.0/nn
                res 
+= p*(1+r-r*q)/pow(1-q,2)
    dp[s] 
= res
    
return res

while 1:
    a 
= fi.readline().split()
    
if a == None or len(a) != 2:
        
break
    N, M 
= int(a[0]), int(a[1])
    nn 
= N*(N-1)/2
    s 
= [ i for i in range(N) ]
    
for i in range(M):
        u, v 
= [ int(i) for i in fi.readline().split() ]
        u 
-= 1
        v 
-= 1
        k 
= s[u]
        
for j in range(N):
            
if s[j] == k:
                s[j] 
= s[v]
    ss 
= [ s.count(i) for i in set(s) ]
    ss.sort()
    
print '----', ti
    mine 
= get(tuple(ss))
    ans 
= float(fo.readline().strip())
    
print 'mine', mine, 'ans', ans
    
print len(dp)
    ti 
+= 1
    


标程
用很简洁的代码写了并查集,值得借鉴!

import java.util.*;
import java.io.File;
import java.io.PrintWriter;
import java.io.FileNotFoundException;

public class interconnect_pm {

    
private static int nn;

    
public static void main(String[] args) throws FileNotFoundException {
        Scanner in 
= new Scanner(new File("in"));
        PrintWriter out 
= new PrintWriter("ans.out");
        
int n = in.nextInt();
        nn 
= (n * (n - 1)) / 2;
        
int m = in.nextInt();
        
int[] p = new int[n];
        
for (int i = 0; i < n; i++) p[i] = i;
        
for (int i = 0; i < m; i++) {
            
int u = in.nextInt();
            
int v = in.nextInt();
            u
--;
            v
--;
            
int k = p[u];
            
for (int j = 0; j < n; j++) {
                
if (p[j] == k) {
                    p[j] 
= p[v];
                }
            }
        }
        List
<Integer> st = new ArrayList<Integer>();
        
for (int i = 0; i < n; i++) {
            
int s = 0;
            
for (int j = 0; j < n; j++) {
                
if (p[j] == i) s++;
            }
            
if (s > 0) {
                st.add(s);
            }
        }
        Collections.sort(st);
        List
<Integer> fn = new ArrayList<Integer>();
        fn.add(n);
        mem.put(fn, 
0.0);
        out.println(get(st));
        System.out.println(mem.size());
        out.close();
    }

    
static Map<List<Integer>, Double> mem = new HashMap<List<Integer>, Double>();

    
private static double get(List<Integer> st) {
        Double ret 
= mem.get(st);
        
if (ret != nullreturn ret;
        
int m = st.size();
        
int[][] a = new int[m][m];
        
for (int i = 0; i < m; i++) {
            
for (int j = i + 1; j < m; j++) {
                a[i][j] 
= st.get(i) * st.get(j);
            }
        }
        
int s = 0;
        
for (int i = 0; i < m; i++) {
            s 
+= st.get(i) * (st.get(i) - 1/ 2;
        }
        
double res = 0;
        
for (int i = 0; i < m; i++) {
            
for (int j = i + 1; j < m; j++) {
                List
<Integer> ss = new ArrayList<Integer>(st.size() - 1);
                
boolean q = true;
                
int z = st.get(i) + st.get(j);
                
for (int k = 0; k < m; k++) {
                    
if (k != i && k != j) {
                        
int zz = st.get(k);
                        
if (q && zz >= z) {
                            q 
= false;
                            ss.add(z);
                        }
                        ss.add(zz);
                    }
                }
                
if (q)
                    ss.add(z);
                
double p = a[i][j] * 1.0 / (nn - s);
                
double e = a[i][j] * 1.0 / ((1 - s * 1.0 / nn) * (1 - s * 1.0 / nn) * nn);
                e 
= e + get(ss) * p;
                res 
+= e;
            }
        }
        System.out.println(st);
        mem.put(st, res);
        
return res;
    }

}


posted on 2011-02-19 11:01 糯米 阅读(610) 评论(0)  编辑 收藏 引用 所属分类: POJ


只有注册用户登录后才能发表评论。
网站导航: 博客园   IT新闻   BlogJava   知识库   博问   管理